From 4aa2d8b2a2ffae77005e9095314348f93a330b23 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 4 Jun 2024 00:53:07 -0500 Subject: [PATCH] Purge shark/ directory, minimal ireert api usage for dynamically loaded plugins --- apps/shark_studio/api/utils.py | 198 ++-- apps/shark_studio/web/index.py | 6 +- shark/__init__.py | 28 - shark/backward_makefx.py | 78 -- shark/dynamo_backend/__init__.py | 0 shark/dynamo_backend/utils.py | 154 --- shark/examples/shark_dynamo/basic_examples.py | 25 - shark/examples/shark_eager/dynamo_demo.ipynb | 309 ------ shark/examples/shark_eager/dynamo_demo.py | 92 -- shark/examples/shark_eager/eager_mode.ipynb | 805 --------------- shark/examples/shark_eager/eager_mode.py | 148 --- .../shark_eager/squeezenet_lockstep.py | 73 -- .../examples/shark_inference/CLIPModel_tf.py | 65 -- .../examples/shark_inference/ESRGAN/README.md | 15 - .../examples/shark_inference/ESRGAN/esrgan.py | 239 ----- .../shark_inference/albert_maskfill_pt.py | 86 -- .../shark_inference/albert_maskfill_tf.py | 100 -- shark/examples/shark_inference/bloom_tank.py | 14 - shark/examples/shark_inference/gpt2_tf.py | 40 - .../examples/shark_inference/llama/README.md | 18 - shark/examples/shark_inference/mega_test.py | 72 -- .../examples/shark_inference/mhlo_example.py | 31 - .../shark_inference/minilm_benchmark.py | 35 - .../shark_inference/minilm_benchmark_tf.py | 61 -- shark/examples/shark_inference/minilm_jax.py | 73 -- .../minilm_jax_requirements.txt | 6 - shark/examples/shark_inference/minilm_jit.py | 23 - shark/examples/shark_inference/minilm_tf.py | 70 -- .../shark_inference/minilm_tf_gpu_config.json | 1 - shark/examples/shark_inference/resnest.py | 39 - .../examples/shark_inference/resnet50_fp16.py | 74 -- .../shark_inference/resnet50_script.py | 85 -- .../examples/shark_inference/sharded_bloom.py | 842 ---------------- .../sharded_bloom_large_models.py | 381 -------- shark/examples/shark_inference/simple_dlrm.py | 390 -------- shark/examples/shark_inference/sparse_arch.py | 311 ------ shark/examples/shark_inference/t5_tf.py | 35 - .../torch_vision_models_script.py | 43 - shark/examples/shark_inference/unet_script.py | 39 - .../examples/shark_inference/upscaler/main.py | 21 - .../upscaler/model_wrappers.py | 98 -- .../shark_inference/upscaler/opt_params.py | 48 - ...pipeline_shark_stable_diffusion_upscale.py | 489 ---------- .../shark_inference/upscaler/upscaler_args.py | 98 -- .../shark_inference/upscaler/utils.py | 230 ----- shark/examples/shark_inference/v_diffusion.py | 15 - .../examples/shark_training/bert_training.py | 48 - .../shark_training/bert_training_load_tf.py | 60 -- .../shark_training/bert_training_tf.py | 98 -- .../shark_training/neural_net_training.py | 44 - .../stable-diffusion-img2img/README.md | 41 - .../stable-diffusion-img2img/setup.sh | 25 - .../stable_diffusion_img2img.py | 600 ------------ .../shark_training/stable_diffusion/README.md | 43 - .../stable_diffusion_fine_tuning.py | 914 ------------------ shark/iree_eager_backend.py | 86 -- shark/iree_utils/__init__.py | 0 shark/iree_utils/_common.py | 164 ---- shark/iree_utils/benchmark_utils.py | 154 --- shark/iree_utils/compile_utils.py | 704 -------------- shark/iree_utils/cpu_utils.py | 65 -- shark/iree_utils/gpu_utils.py | 209 ---- shark/iree_utils/metal_utils.py | 102 -- shark/iree_utils/trace.py | 76 -- shark/iree_utils/vulkan_target_env_utils.py | 538 ----------- shark/iree_utils/vulkan_utils.py | 221 ----- shark/model_annotation.py | 468 --------- shark/parser.py | 170 ---- shark/shark_benchmark_runner.py | 
501 ---------- shark/shark_compile.py | 241 ----- shark/shark_downloader.py | 297 ------ shark/shark_eager/shark_eager.py | 212 ---- shark/shark_generate_model_config.py | 153 --- shark/shark_importer.py | 819 ---------------- shark/shark_inference.py | 243 ----- shark/shark_runner.py | 127 --- shark/shark_trainer.py | 163 ---- shark/stress_test.py | 315 ------ shark/tests/test_shark_importer.py | 144 --- shark/tests/test_stress_test.py | 31 - shark/tests/test_txt2img_ui.py | 62 -- shark/tflite_utils.py | 208 ---- shark/torch_mlir_lockstep_tensor.py | 220 ----- shark/torch_mlir_utils.py | 90 -- 84 files changed, 75 insertions(+), 14684 deletions(-) delete mode 100644 shark/__init__.py delete mode 100644 shark/backward_makefx.py delete mode 100644 shark/dynamo_backend/__init__.py delete mode 100644 shark/dynamo_backend/utils.py delete mode 100644 shark/examples/shark_dynamo/basic_examples.py delete mode 100644 shark/examples/shark_eager/dynamo_demo.ipynb delete mode 100644 shark/examples/shark_eager/dynamo_demo.py delete mode 100644 shark/examples/shark_eager/eager_mode.ipynb delete mode 100644 shark/examples/shark_eager/eager_mode.py delete mode 100644 shark/examples/shark_eager/squeezenet_lockstep.py delete mode 100644 shark/examples/shark_inference/CLIPModel_tf.py delete mode 100644 shark/examples/shark_inference/ESRGAN/README.md delete mode 100644 shark/examples/shark_inference/ESRGAN/esrgan.py delete mode 100644 shark/examples/shark_inference/albert_maskfill_pt.py delete mode 100644 shark/examples/shark_inference/albert_maskfill_tf.py delete mode 100644 shark/examples/shark_inference/bloom_tank.py delete mode 100644 shark/examples/shark_inference/gpt2_tf.py delete mode 100644 shark/examples/shark_inference/llama/README.md delete mode 100644 shark/examples/shark_inference/mega_test.py delete mode 100644 shark/examples/shark_inference/mhlo_example.py delete mode 100644 shark/examples/shark_inference/minilm_benchmark.py delete mode 100644 shark/examples/shark_inference/minilm_benchmark_tf.py delete mode 100644 shark/examples/shark_inference/minilm_jax.py delete mode 100644 shark/examples/shark_inference/minilm_jax_requirements.txt delete mode 100644 shark/examples/shark_inference/minilm_jit.py delete mode 100644 shark/examples/shark_inference/minilm_tf.py delete mode 100644 shark/examples/shark_inference/minilm_tf_gpu_config.json delete mode 100644 shark/examples/shark_inference/resnest.py delete mode 100644 shark/examples/shark_inference/resnet50_fp16.py delete mode 100644 shark/examples/shark_inference/resnet50_script.py delete mode 100644 shark/examples/shark_inference/sharded_bloom.py delete mode 100644 shark/examples/shark_inference/sharded_bloom_large_models.py delete mode 100644 shark/examples/shark_inference/simple_dlrm.py delete mode 100644 shark/examples/shark_inference/sparse_arch.py delete mode 100644 shark/examples/shark_inference/t5_tf.py delete mode 100644 shark/examples/shark_inference/torch_vision_models_script.py delete mode 100644 shark/examples/shark_inference/unet_script.py delete mode 100644 shark/examples/shark_inference/upscaler/main.py delete mode 100644 shark/examples/shark_inference/upscaler/model_wrappers.py delete mode 100644 shark/examples/shark_inference/upscaler/opt_params.py delete mode 100644 shark/examples/shark_inference/upscaler/pipeline_shark_stable_diffusion_upscale.py delete mode 100644 shark/examples/shark_inference/upscaler/upscaler_args.py delete mode 100644 shark/examples/shark_inference/upscaler/utils.py delete mode 100644 
shark/examples/shark_inference/v_diffusion.py delete mode 100644 shark/examples/shark_training/bert_training.py delete mode 100644 shark/examples/shark_training/bert_training_load_tf.py delete mode 100644 shark/examples/shark_training/bert_training_tf.py delete mode 100644 shark/examples/shark_training/neural_net_training.py delete mode 100644 shark/examples/shark_training/stable-diffusion-img2img/README.md delete mode 100644 shark/examples/shark_training/stable-diffusion-img2img/setup.sh delete mode 100644 shark/examples/shark_training/stable-diffusion-img2img/stable_diffusion_img2img.py delete mode 100644 shark/examples/shark_training/stable_diffusion/README.md delete mode 100644 shark/examples/shark_training/stable_diffusion/stable_diffusion_fine_tuning.py delete mode 100644 shark/iree_eager_backend.py delete mode 100644 shark/iree_utils/__init__.py delete mode 100644 shark/iree_utils/_common.py delete mode 100644 shark/iree_utils/benchmark_utils.py delete mode 100644 shark/iree_utils/compile_utils.py delete mode 100644 shark/iree_utils/cpu_utils.py delete mode 100644 shark/iree_utils/gpu_utils.py delete mode 100644 shark/iree_utils/metal_utils.py delete mode 100644 shark/iree_utils/trace.py delete mode 100644 shark/iree_utils/vulkan_target_env_utils.py delete mode 100644 shark/iree_utils/vulkan_utils.py delete mode 100644 shark/model_annotation.py delete mode 100644 shark/parser.py delete mode 100644 shark/shark_benchmark_runner.py delete mode 100644 shark/shark_compile.py delete mode 100644 shark/shark_downloader.py delete mode 100644 shark/shark_eager/shark_eager.py delete mode 100644 shark/shark_generate_model_config.py delete mode 100644 shark/shark_importer.py delete mode 100644 shark/shark_inference.py delete mode 100644 shark/shark_runner.py delete mode 100644 shark/shark_trainer.py delete mode 100644 shark/stress_test.py delete mode 100644 shark/tests/test_shark_importer.py delete mode 100644 shark/tests/test_stress_test.py delete mode 100644 shark/tests/test_txt2img_ui.py delete mode 100644 shark/tflite_utils.py delete mode 100644 shark/torch_mlir_lockstep_tensor.py delete mode 100644 shark/torch_mlir_utils.py diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index efbc205f03..8eebb97966 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -12,12 +12,60 @@ from cpuinfo import get_cpu_info +def iree_device_map(device): + uri_parts = device.split("://", 2) + iree_driver = ( + _IREE_DEVICE_MAP[uri_parts[0]] + if uri_parts[0] in _IREE_DEVICE_MAP + else uri_parts[0] + ) + if len(uri_parts) == 1: + return iree_driver + elif "rocm" in uri_parts: + return "rocm" + else: + return f"{iree_driver}://{uri_parts[1]}" + + +def get_supported_device_list(): + return list(_IREE_DEVICE_MAP.keys()) + + +_IREE_DEVICE_MAP = { + "cpu": "local-task", + "cpu-task": "local-task", + "cpu-sync": "local-sync", + "cuda": "cuda", + "vulkan": "vulkan", + "metal": "metal", + "rocm": "rocm", + "hip": "hip", + "intel-gpu": "level_zero", +} + + +def iree_target_map(device): + if "://" in device: + device = device.split("://")[0] + return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device + + +_IREE_TARGET_MAP = { + "cpu": "llvm-cpu", + "cpu-task": "llvm-cpu", + "cpu-sync": "llvm-cpu", + "cuda": "cuda", + "vulkan": "vulkan-spirv", + "metal": "metal", + "rocm": "rocm", + "hip": "rocm", + "intel-gpu": "opencl-spirv", +} + def get_available_devices(): - return ["AMD Radeon 780M => rocm"] def get_devices_by_name(driver_name): - from 
shark.iree_utils._common import iree_device_map
 
     device_list = []
     try:
@@ -91,13 +139,29 @@ def get_devices_by_name(driver_name):
             break
     return available_devices
 
 
+def clean_device_info(raw_device):
+    # return appropriate device and device_id for consumption by Studio pipeline
+    # Multiple devices only supported for vulkan and rocm (as of now).
+    # default device must be selected for all others
-def parse_device(device_str, target_override=""):
-    from shark.iree_utils.compile_utils import (
-        clean_device_info,
-        get_iree_target_triple,
-        iree_target_map,
+    device_id = None
+    device = (
+        raw_device
+        if "=>" not in raw_device
+        else raw_device.split("=>")[1].strip()
     )
 
+    if "://" in device:
+        device, device_id = device.split("://")
+        if len(device_id) <= 2:
+            device_id = int(device_id)
+
+    if device not in ["hip", "rocm", "vulkan"]:
+        device_id = None
+    if device in ["hip", "rocm", "vulkan"] and device_id == None:
+        device_id = 0
+    return device, device_id
+
+def parse_device(device_str, target_override=""):
     rt_driver, device_id = clean_device_info(device_str)
     target_backend = iree_target_map(rt_driver)
@@ -144,9 +208,6 @@ def get_rocm_target_chip(device_str):
         if key in device_str:
             return rocm_chip_map[key]
     return None
-    # raise AssertionError(
-    #     f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues."
-    # )
 
 
 def get_all_devices(driver_name):
@@ -179,7 +240,6 @@ def get_device_mapping(driver, key_combination=3):
         dict: map to possible device names user can input mapped to desired
         combination of name/path.
     """
-    from shark.iree_utils._common import iree_device_map
 
     driver = iree_device_map(driver)
     device_list = get_all_devices(driver)
@@ -226,118 +286,4 @@ def get_opt_flags(model, precision="fp16"):
         # Due to lack of support for multi-reduce, we always collapse reduction
         # dims before dispatch formation right now.
        iree_flags += ["--iree-flow-collapse-reduction-dims"]
-    return iree_flags
-
-
-# def map_device_to_name_path(device, key_combination=3):
-#     """Gives the appropriate device data (supported name/path) for user
-#     selected execution device
-#     Args:
-#         device (str): user
-#         key_combination (int, optional): choice for mapping value for
-#             device name.
-#             1 : path
-#             2 : name
-#             3 : (name, path)
-#             Defaults to 3.
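Note on the hunks above: they inline the IREE device/target maps and clean_device_info into apps/shark_studio/api/utils.py instead of importing them from the deleted shark/ tree. The following is a minimal sketch of how these helpers are expected to compose; the import path, the example "... => rocm://0" device string, and the iree.runtime device query are illustrative assumptions, not part of this patch.

# Sketch only: resolve a Studio device string to an IREE runtime driver and a
# compile-time target backend using the helpers added above.
from iree import runtime as ireert

from apps.shark_studio.api.utils import (
    clean_device_info,
    iree_device_map,
    iree_target_map,
)


def resolve_studio_device(raw_device: str):
    # e.g. "AMD Radeon 780M => rocm://0" -> device "rocm", device_id 0
    device, device_id = clean_device_info(raw_device)
    uri = device if device_id is None else f"{device}://{device_id}"
    rt_driver = iree_device_map(uri)    # runtime driver, e.g. "rocm"
    target = iree_target_map(device)    # compile target, e.g. "rocm"
    return rt_driver, target, device_id


def query_devices(driver_name: str):
    # Minimal ireert usage (assumed enumeration path): list the devices a
    # HAL driver exposes, without pulling in any of the removed shark/ code.
    driver = ireert.get_driver(iree_device_map(driver_name))
    return driver.query_available_devices()


if __name__ == "__main__":
    print(resolve_studio_device("AMD Radeon 780M => rocm://0"))
    print(query_devices("cpu"))  # "cpu" maps to the "local-task" driver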
-# Raises: -# ValueError: -# Returns: -# str / tuple: returns the mapping str or tuple of mapping str for -# the device depending on key_combination value -# """ -# driver = device.split("://")[0] -# device_map = get_device_mapping(driver, key_combination) -# try: -# device_mapping = device_map[device] -# except KeyError: -# raise ValueError(f"Device '{device}' is not a valid device.") -# return device_mapping - -# def get_devices_by_name(driver_name): -# from shark.iree_utils._common import iree_device_map - -# device_list = [] -# try: -# driver_name = iree_device_map(driver_name) -# device_list_dict = get_all_devices(driver_name) -# print(f"{driver_name} devices are available.") -# except: -# print(f"{driver_name} devices are not available.") -# else: -# cpu_name = get_cpu_info()["brand_raw"] -# for i, device in enumerate(device_list_dict): -# device_name = ( -# cpu_name if device["name"] == "default" else device["name"] -# ) -# if "local" in driver_name: -# device_list.append( -# f"{device_name} => {driver_name.replace('local', 'cpu')}" -# ) -# else: -# # for drivers with single devices -# # let the default device be selected without any indexing -# if len(device_list_dict) == 1: -# device_list.append(f"{device_name} => {driver_name}") -# else: -# device_list.append(f"{device_name} => {driver_name}://{i}") -# return device_list - -# set_iree_runtime_flags() - -# available_devices = [] -# from shark.iree_utils.vulkan_utils import ( -# get_all_vulkan_devices, -# ) - -# vulkaninfo_list = get_all_vulkan_devices() -# vulkan_devices = [] -# id = 0 -# for device in vulkaninfo_list: -# vulkan_devices.append(f"{device.strip()} => vulkan://{id}") -# id += 1 -# if id != 0: -# print(f"vulkan devices are available.") -# available_devices.extend(vulkan_devices) -# metal_devices = get_devices_by_name("metal") -# available_devices.extend(metal_devices) -# cuda_devices = get_devices_by_name("cuda") -# available_devices.extend(cuda_devices) -# rocm_devices = get_devices_by_name("rocm") -# available_devices.extend(rocm_devices) -# cpu_device = get_devices_by_name("cpu-sync") -# available_devices.extend(cpu_device) -# cpu_device = get_devices_by_name("cpu-task") -# available_devices.extend(cpu_device) -# return available_devices - - -# # Generate and return a new seed if the provided one is not in the -# # supported range (including -1) -# def sanitize_seed(seed: int | str): -# seed = int(seed) -# uint32_info = np.iinfo(np.uint32) -# uint32_min, uint32_max = uint32_info.min, uint32_info.max -# if seed < uint32_min or seed >= uint32_max: -# seed = randint(uint32_min, uint32_max) -# return seed - - -# # take a seed expression in an input format and convert it to -# # a list of integers, where possible -# def parse_seed_input(seed_input: str | list | int): -# if isinstance(seed_input, str): -# try: -# seed_input = json.loads(seed_input) -# except (ValueError, TypeError): -# seed_input = None - -# if isinstance(seed_input, int): -# return [seed_input] - -# if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input): -# return seed_input - -# raise TypeError( -# "Seed input must be an integer or an array of integers in JSON format" -# ) + return iree_flags \ No newline at end of file diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py index f4f2539479..f32bda123b 100644 --- a/apps/shark_studio/web/index.py +++ b/apps/shark_studio/web/index.py @@ -83,7 +83,7 @@ def webui(): launch_api = cmd_opts.api initialize.initialize() - from ui.chat import chat_element + #from 
ui.chat import chat_element from ui.sd import sd_element from ui.outputgallery import outputgallery_element @@ -194,8 +194,8 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): sd_element.render() with gr.TabItem(label="Output Gallery", id=1): outputgallery_element.render() - with gr.TabItem(label="Chat Bot", id=2, visible=False): - chat_element.render() + # with gr.TabItem(label="Chat Bot", id=2): + # chat_element.render() studio_web.queue() diff --git a/shark/__init__.py b/shark/__init__.py deleted file mode 100644 index e5b86b2828..0000000000 --- a/shark/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -import importlib -import logging - -from torch._dynamo import register_backend - -log = logging.getLogger(__name__) - - -@register_backend -def shark(model, inputs, *, options): - try: - from shark.dynamo_backend.utils import SharkBackend - except ImportError: - log.exception( - "Unable to import SHARK - High Performance Machine Learning Distribution" - "Please install the right version of SHARK that matches the PyTorch version being used. " - "Refer to https://github.com/nod-ai/SHARK/ for details." - ) - raise - return SharkBackend(model, inputs, options) - - -def has_shark(): - try: - importlib.import_module("shark") - return True - except ImportError: - return False diff --git a/shark/backward_makefx.py b/shark/backward_makefx.py deleted file mode 100644 index a6e70e6577..0000000000 --- a/shark/backward_makefx.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from torch._decomp import get_decompositions -from torch.fx.experimental.proxy_tensor import make_fx -from torch.nn.utils import stateless - -from torch import fx -import tempfile - - -class MakeFxModule: - def __init__(self, model, inputs, labels=None, custom_inference_fn=None): - self.model = model - self.inputs = inputs - self.custom_inference_fn = custom_inference_fn - self.training_graph = None - - # Doesn't replace the None type. - def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule): - for node in fx_g.graph.nodes: - if node.op == "output": - # output nodes always have one argument - node_arg = node.args[0] - out_nodes = [] - if isinstance(node_arg, list): - # Don't return NoneType elements. - for out_node in node_arg: - if not isinstance(out_node, type(None)): - out_nodes.append(out_node) - # If there is a single tensor/element to be returned don't - # a tuple for it. 
- if len(out_nodes) == 1: - node.args = out_nodes - else: - node.args = (tuple(out_nodes),) - fx_g.graph.lint() - fx_g.recompile() - return fx_g - - def generate_graph(self): - fx_g = make_fx( - self.custom_inference_fn, - decomposition_table=get_decompositions( - [ - torch.ops.aten.embedding_dense_backward, - torch.ops.aten.native_layer_norm_backward, - torch.ops.aten.slice_backward, - torch.ops.aten.select_backward, - ] - ), - )( - dict(self.model.named_parameters()), - dict(self.model.named_buffers()), - self.inputs, - ) - fx_g.graph.set_codegen(torch.fx.graph.CodeGen()) - fx_g.recompile() - fx_g = self.change_fx_graph_return_to_tuple(fx_g) - ts_g = torch.jit.script(fx_g) - temp = tempfile.NamedTemporaryFile( - suffix="_shark_ts", prefix="temp_ts_" - ) - ts_g.save(temp.name) - new_ts = torch.jit.load(temp.name) - self.training_graph = new_ts diff --git a/shark/dynamo_backend/__init__.py b/shark/dynamo_backend/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/shark/dynamo_backend/utils.py b/shark/dynamo_backend/utils.py deleted file mode 100644 index 90808d630d..0000000000 --- a/shark/dynamo_backend/utils.py +++ /dev/null @@ -1,154 +0,0 @@ -import functools -from typing import List, Optional -import torch -from torch.fx.experimental.proxy_tensor import make_fx -from torch._functorch.compile_utils import strip_overloads -from shark.shark_inference import SharkInference -from torch._decomp import get_decompositions -from torch.func import functionalize -import io -import torch_mlir - - -# TODO: Control decompositions. -def default_decompositions(): - return get_decompositions( - [ - torch.ops.aten.embedding_dense_backward, - torch.ops.aten.native_layer_norm_backward, - torch.ops.aten.slice_backward, - torch.ops.aten.select_backward, - torch.ops.aten.norm.ScalarOpt_dim, - torch.ops.aten.native_group_norm, - torch.ops.aten.upsample_bilinear2d.vec, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes, - torch.ops.aten.native_layer_norm, - torch.ops.aten.masked_fill.Tensor, - torch.ops.aten.masked_fill.Scalar, - ] - ) - - -def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]: - removed_indexes = [] - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, (list, tuple)): - node_arg = list(node_arg) - node_args_len = len(node_arg) - for i in range(node_args_len): - curr_index = node_args_len - (i + 1) - if node_arg[curr_index] is None: - removed_indexes.append(curr_index) - node_arg.pop(curr_index) - node.args = (tuple(node_arg),) - break - - if len(removed_indexes) > 0: - fx_g.graph.lint() - fx_g.graph.eliminate_dead_code() - fx_g.recompile() - removed_indexes.sort() - return removed_indexes - - -def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool: - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, tuple): - return len(node_arg) == 0 - return False - - -def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool: - """ - Replace tuple with tuple element in functions that return one-element tuples. - Returns true if an unwrapping took place, and false otherwise. 
- """ - unwrapped_tuple = False - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, tuple): - if len(node_arg) == 1: - node.args = (node_arg[0],) - unwrapped_tuple = True - break - - if unwrapped_tuple: - fx_g.graph.lint() - fx_g.recompile() - return unwrapped_tuple - - -class SharkBackend: - def __init__( - self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict - ): - self.fx_g = fx_g - self.inputs = inputs - self.shark_module = None - self.device: str = options.get("device", "cpu") - self.was_unwrapped: bool = False - self.none_indices: list = [] - self._modify_fx_g() - self.compile() - - def _modify_fx_g(self): - self.none_indices = _remove_nones(self.fx_g) - self.was_unwrapped = _unwrap_single_tuple_return(self.fx_g) - - def compile(self): - gm = make_fx( - functionalize(self.fx_g), - decomposition_table=default_decompositions(), - )(*self.inputs) - gm.graph.set_codegen(torch.fx.graph.CodeGen()) - gm.recompile() - strip_overloads(gm) - ts_g = torch.jit.script(gm) - mlir_module = torch_mlir.compile( - ts_g, self.inputs, output_type="linalg-on-tensors" - ) - bytecode_stream = io.BytesIO() - mlir_module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - from shark.shark_inference import SharkInference - - shark_module = SharkInference( - mlir_module=bytecode, - device=self.device, - mlir_dialect="tm_tensor", - ) - shark_module.compile(extra_args=[]) - self.shark_module = shark_module - - def __call__(self, *inputs): - np_inputs = [x.contiguous().detach().cpu().numpy() for x in inputs] - np_outs = self.shark_module("forward", np_inputs) - if self.was_unwrapped: - np_outs = [ - np_outs, - ] - - if not isinstance(np_outs, list): - res = torch.from_numpy(np_outs) - return res - - result = [torch.from_numpy(x) for x in np_outs] - for r_in in self.none_indices: - result.insert(r_in, None) - result = tuple(result) - return result diff --git a/shark/examples/shark_dynamo/basic_examples.py b/shark/examples/shark_dynamo/basic_examples.py deleted file mode 100644 index 16b21fa570..0000000000 --- a/shark/examples/shark_dynamo/basic_examples.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -import shark - - -def foo(x, a): - if x.shape[0] > 3: - return x + a - else: - return x + 3 - - -shark_options = {"device": "cpu"} -compiled = torch.compile(foo, backend="shark", options=shark_options) - -input = torch.ones(4) - -x = compiled(input, input) - -print(x) - -input = torch.ones(3) - -x = compiled(input, input) - -print(x) diff --git a/shark/examples/shark_eager/dynamo_demo.ipynb b/shark/examples/shark_eager/dynamo_demo.ipynb deleted file mode 100644 index 526ff95b28..0000000000 --- a/shark/examples/shark_eager/dynamo_demo.ipynb +++ /dev/null @@ -1,309 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "collapsed": true, - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/mlevental/miniconda3/envs/torch-mlir/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "# standard imports\n", - "import torch\n", - "from shark.iree_utils import get_iree_compiled_module" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "outputs": [], - "source": [ - "# torch dynamo related imports\n", - "try:\n", - " import torchdynamo\n", - " from torchdynamo.optimizations.backends import create_backend\n", - " from torchdynamo.optimizations.subgraph import SubGraph\n", - "except ModuleNotFoundError:\n", - " print(\n", - " \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n", - " )\n", - " exit()\n", - "\n", - "# torch-mlir imports for compiling\n", - "from torch_mlir import compile, OutputType" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "[TorchDynamo](https://github.com/pytorch/torchdynamo) is a compiler for PyTorch programs that uses the [frame evaluation API](https://www.python.org/dev/peps/pep-0523/) in CPython to dynamically modify Python bytecode right before it is executed. It creates this FX Graph through bytecode analysis and is designed to mix Python execution with compiled backends." - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 3, - "outputs": [], - "source": [ - "def toy_example(*args):\n", - " a, b = args\n", - "\n", - " x = a / (torch.abs(a) + 1)\n", - " if b.sum() < 0:\n", - " b = b * -1\n", - " return x * b" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 4, - "outputs": [], - "source": [ - "# compiler that lowers fx_graph to through MLIR\n", - "def __torch_mlir(fx_graph, *args, **kwargs):\n", - " assert isinstance(\n", - " fx_graph, torch.fx.GraphModule\n", - " ), \"Model must be an FX GraphModule.\"\n", - "\n", - " def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule):\n", - " \"\"\"Replace tuple with tuple element in functions that return one-element tuples.\"\"\"\n", - "\n", - " for node in fx_g.graph.nodes:\n", - " if node.op == \"output\":\n", - " assert (\n", - " len(node.args) == 1\n", - " ), \"Output node must have a single argument\"\n", - " node_arg = node.args[0]\n", - " if isinstance(node_arg, tuple) and len(node_arg) == 1:\n", - " node.args = (node_arg[0],)\n", - " fx_g.graph.lint()\n", - " fx_g.recompile()\n", - " return fx_g\n", - "\n", - " fx_graph = _unwrap_single_tuple_return(fx_graph)\n", - " ts_graph = torch.jit.script(fx_graph)\n", - "\n", - " # torchdynamo does munges the args differently depending on whether you use\n", - " # the @torchdynamo.optimize decorator or the context manager\n", - " if isinstance(args, tuple):\n", - " args = list(args)\n", - " assert isinstance(args, list)\n", - " if len(args) == 1 and isinstance(args[0], list):\n", - " args = args[0]\n", - "\n", - " linalg_module = compile(\n", - " ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n", - " )\n", - " callable, _ = get_iree_compiled_module(\n", - " linalg_module, \"cuda\", func_name=\"forward\"\n", - " )\n", - "\n", - " def forward(*inputs):\n", - " return callable(*inputs)\n", - "\n", - " return forward" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "Simplest way to use 
TorchDynamo with the `torchdynamo.optimize` context manager:" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 5, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found 1 device(s).\n", - "Device: 0\n", - " Name: NVIDIA GeForce RTX 3080\n", - " Compute Capability: 8.6\n", - "[-0.40066046 -0.4210303 0.03225489 -0.44849953 0.10370405 -0.04422468\n", - " 0.33262825 -0.20109026 0.02102537 -0.24882983]\n", - "[-0.07824923 -0.17004533 0.06439921 -0.06163602 0.26633525 -1.1560082\n", - " -0.06660341 0.24227881 0.1462235 -0.32055548]\n", - "[-0.01464001 0.442209 -0.0607936 -0.5477967 -0.25226554 -0.08588809\n", - " -0.30497575 0.00061084 -0.50069696 0.2317973 ]\n", - "[ 0.25726247 0.39388427 -0.24093066 0.12316308 -0.01981307 0.5661146\n", - " 0.26199922 0.8123446 -0.01576749 0.30846444]\n", - "[ 0.7878203 -0.45975062 -0.29956317 -0.07032048 -0.55817443 -0.62506855\n", - " -1.6837492 -0.38442805 0.28220773 -1.5325156 ]\n", - "[ 0.07975311 0.67754704 -0.30927914 0.00347631 -0.07326564 0.01893554\n", - " -0.7518105 -0.03078967 -0.07623022 0.38865626]\n", - "[-0.7751679 -0.5841397 -0.6622711 0.18574935 -0.6049372 0.02844244\n", - " -0.20471913 0.3337415 -0.3619432 -0.35087156]\n", - "[-0.08569919 -0.10775139 -0.02338934 0.21933547 -0.46712473 0.00062137\n", - " -0.58207744 0.06457533 0.18276742 0.03866556]\n", - "[-0.2311981 -0.43036282 0.20561649 -0.10363232 -0.13248594 0.02885137\n", - " -0.31241602 -0.36907142 0.08861586 0.2331427 ]\n", - "[-0.07273526 -0.31246194 -0.24218291 -0.24145737 0.0364486 0.14382267\n", - " -0.00531162 0.15447603 -0.5220248 -0.09016377]\n" - ] - } - ], - "source": [ - "with torchdynamo.optimize(__torch_mlir):\n", - " for _ in range(10):\n", - " print(toy_example(torch.randn(10), torch.randn(10)))" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "It can also be used through a decorator:" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 6, - "outputs": [], - "source": [ - "@create_backend\n", - "def torch_mlir(subgraph, *args, **kwargs):\n", - " assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n", - " return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n", - "\n", - "\n", - "@torchdynamo.optimize(\"torch_mlir\")\n", - "def toy_example2(*args):\n", - " a, b = args\n", - "\n", - " x = a / (torch.abs(a) + 1)\n", - " if b.sum() < 0:\n", - " b = b * -1\n", - " return x * b" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 7, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found 1 device(s).\n", - "Device: 0\n", - " Name: NVIDIA GeForce RTX 3080\n", - " Compute Capability: 8.6\n", - "[-0.35494277 0.03409214 -0.02271946 0.7335942 0.03122527 -0.41881397\n", - " -0.6609761 -0.6418614 0.29336175 -0.01973678]\n", - "[-2.7246824e-01 -3.5543957e-01 6.0087401e-01 -7.4570496e-03\n", - " -4.2481605e-02 -5.0296803e-04 7.2928613e-01 -1.4673788e-03\n", - " -2.7621329e-01 -6.0995776e-02]\n", - "[-0.03165906 0.3889693 0.24052973 0.27279532 -0.02773128 -0.12602475\n", - " -1.0124422 0.5720256 -0.35437614 -0.20992722]\n", - "[-0.41831446 0.5525326 -0.29749998 -0.17044766 0.11804754 -0.05210691\n", - " -0.46145165 -0.8776549 0.10090438 
0.17463352]\n", - "[ 0.02194221 0.20959911 0.26973712 0.12551276 -0.0020404 0.1490246\n", - " -0.04456685 1.1100804 0.8105744 0.6676846 ]\n", - "[ 0.06528181 -0.13591261 0.5370964 -0.4398162 -0.03372452 0.9691372\n", - " -0.01120087 0.2947028 0.4804801 -0.3324341 ]\n", - "[ 0.33549032 -0.23001772 -0.08681437 0.16490957 -0.11223086 0.09168988\n", - " 0.02403045 0.17344482 0.46406478 -0.00129451]\n", - "[-0.27475086 0.42384806 1.9090122 -0.41147137 -0.6888369 0.08435658\n", - " -0.26628923 -0.17436793 -0.8058869 -0.02582378]\n", - "[-0.10109414 0.08681287 -0.10055986 0.6858881 0.29267687 -0.02797117\n", - " -0.01425194 0.4882803 0.3551982 -0.858935 ]\n", - "[-0.22086617 0.524994 0.17721705 -0.03813264 -0.54570735 -0.4421502\n", - " 0.11938014 -0.01122053 0.39294165 -0.61770755]\n" - ] - } - ], - "source": [ - "for _ in range(10):\n", - " print(toy_example2(torch.randn(10), torch.randn(10)))" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/shark/examples/shark_eager/dynamo_demo.py b/shark/examples/shark_eager/dynamo_demo.py deleted file mode 100644 index a4cf9ca958..0000000000 --- a/shark/examples/shark_eager/dynamo_demo.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -from torch_mlir import compile, OutputType - -from shark.iree_utils import get_iree_compiled_module - -try: - import torchdynamo - from torchdynamo.optimizations.backends import create_backend - from torchdynamo.optimizations.subgraph import SubGraph -except ModuleNotFoundError: - print( - "Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo" - ) - exit() - -NUM_ITERS = 10 - - -def __torch_mlir(fx_graph, *args, **kwargs): - assert isinstance( - fx_graph, torch.fx.GraphModule - ), "Model must be an FX GraphModule." 
- - def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule): - """Replace tuple with tuple element in functions that return one-element tuples.""" - - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, tuple) and len(node_arg) == 1: - node.args = (node_arg[0],) - fx_g.graph.lint() - fx_g.recompile() - return fx_g - - fx_graph = _unwrap_single_tuple_return(fx_graph) - ts_graph = torch.jit.script(fx_graph) - - if isinstance(args, tuple): - args = list(args) - assert isinstance(args, list) - if len(args) == 1 and isinstance(args[0], list): - args = args[0] - - linalg_module = compile( - ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS - ) - callable, _ = get_iree_compiled_module( - linalg_module, "cuda", func_name="forward" - ) - - def forward(*inputs): - return callable(*inputs) - - return forward - - -def toy_example(*args): - a, b = args - - x = a / (torch.abs(a) + 1) - if b.sum() < 0: - b = b * -1 - return x * b - - -with torchdynamo.optimize(__torch_mlir): - for _ in range(10): - print(toy_example(torch.randn(10), torch.randn(10))) - - -@create_backend -def torch_mlir(subgraph, *args, **kwargs): - assert isinstance(subgraph, SubGraph), "Model must be a dynamo SubGraph." - return __torch_mlir(subgraph.model, *list(subgraph.example_inputs)) - - -@torchdynamo.optimize("torch_mlir") -def toy_example2(*args): - a, b = args - - x = a / (torch.abs(a) + 1) - if b.sum() < 0: - b = b * -1 - return x * b - - -for _ in range(10): - print(toy_example2(torch.randn(10), torch.randn(10))) diff --git a/shark/examples/shark_eager/eager_mode.ipynb b/shark/examples/shark_eager/eager_mode.ipynb deleted file mode 100644 index 7d162e9221..0000000000 --- a/shark/examples/shark_eager/eager_mode.ipynb +++ /dev/null @@ -1,805 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/mlevental/miniconda3/envs/torch-mlir/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "# standard imports\n", - "import torch\n", - "from torch_mlir.eager_mode import torch_mlir_tensor" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 2, - "outputs": [], - "source": [ - "# eager mode imports\n", - "from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor\n", - "from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "The simplest way of using Eager Mode (through IREE) requires setting a \"backend\":" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 3, - "outputs": [], - "source": [ - "torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend(\"cpu\")" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "and wrapping all your `torch.Tensor`s:" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 4, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TorchMLIRTensor(, backend=EagerModeIREELinalgOnTensorsBackend)\n", - "TorchMLIRTensor(, backend=EagerModeIREELinalgOnTensorsBackend)\n" - ] - } - ], - "source": [ - "NUM_ITERS = 10\n", - "\n", - "t = torch.ones((10, 10))\n", - "u = 2 * torch.ones((10, 10))\n", - "\n", - "tt = TorchMLIRTensor(t)\n", - "print(tt)\n", - "uu = TorchMLIRTensor(u)\n", - "print(uu)" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "`TorchMLIRTensor` is a \"tensor wrapper subclass\" (more info [here](https://github.com/albanD/subclass_zoo)) that keeps the IREE `DeviceArray` in a field `elem`:" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 5, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 
3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 
2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 
2.]]\n" - ] - } - ], - "source": [ - "for i in range(NUM_ITERS):\n", - " yy = tt + uu\n", - " print(type(yy))\n", - " print(yy.elem.to_host())\n", - " yy = tt * uu\n", - " print(type(yy))\n", - " print(yy.elem.to_host())" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "If you have a GPU (and CUDA installed) that works too (you can verify by having `watch -n1 nvidia-smi` up in a terminal while running the next cell):" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 6, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TorchMLIRTensor(, backend=EagerModeIREELinalgOnTensorsBackend)\n", - "TorchMLIRTensor(, backend=EagerModeIREELinalgOnTensorsBackend)\n", - "[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n", - " [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n" - ] - } - ], - "source": [ - "torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend(\"gpu\")\n", - "\n", - "t = torch.ones((10, 10))\n", - "u = 2 * torch.ones((10, 10))\n", - "\n", - "tt = TorchMLIRTensor(t)\n", - "print(tt)\n", - "uu = TorchMLIRTensor(u)\n", - "print(uu)\n", - "\n", - "yy = tt + uu\n", - "print(yy.elem.to_host())\n", - "yy = tt * uu\n", - "print(yy.elem.to_host())" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "There is a convenience class `SharkEagerMode` that will handle both the installation of the backend and the wrapping of `torch.Tensor`s:" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 7, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TorchMLIRTensor(, backend=EagerModeIREELinalgOnTensorsBackend)\n", - "TorchMLIRTensor(, backend=EagerModeIREELinalgOnTensorsBackend)\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 
2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 
2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 
1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n" - ] - } - ], - "source": [ - "# eager mode RAII\n", - "from shark.shark_runner import SharkEagerMode\n", - "\n", - "shark_eager_mode = SharkEagerMode(\"cpu\")\n", - "\n", - "t = torch.ones((10, 10))\n", - "u = torch.ones((10, 10))\n", - "\n", - "print(t)\n", - "print(u)\n", - "\n", - "for i in range(NUM_ITERS):\n", - " yy = t + u\n", - " print(type(yy))\n", - " print(yy.elem.to_host())\n", - " yy = t * u\n", - " print(type(yy))\n", - " print(yy.elem.to_host())" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "The `SharkEagerMode` class is a hacky take on [RAII](https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization) that defines a \"deleter\" that runs when an instantiation (of `SharkEagerMode`) is garbage collected. Takeaway is that if you want to turn off `SharkEagerMode`, or switch backends, you need to `del` the instance:" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 8, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TorchMLIRTensor(, backend=EagerModeIREELinalgOnTensorsBackend)\n", - "TorchMLIRTensor(, backend=EagerModeIREELinalgOnTensorsBackend)\n", - "\n", - "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", - " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n", - "\n", - "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n", - " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n" - ] - } - ], - "source": [ - "del shark_eager_mode\n", - "shark_eager_mode = SharkEagerMode(\"cuda\")\n", - "\n", - "t = torch.ones((10, 10))\n", - "u = torch.ones((10, 10))\n", - "\n", - "print(t)\n", - "print(u)\n", - "\n", - "yy = t + u\n", - "print(type(yy))\n", - "print(yy.elem.to_host())\n", - "yy = t * u\n", - "print(type(yy))\n", - "print(yy.elem.to_host())" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/shark/examples/shark_eager/eager_mode.py b/shark/examples/shark_eager/eager_mode.py deleted file mode 100644 index a440cc7e8e..0000000000 --- a/shark/examples/shark_eager/eager_mode.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from torch.utils.cpp_extension import load_inline, include_paths -from torch_mlir.eager_mode import torch_mlir_tensor -from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor - -from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend -from shark.shark_runner import SharkEagerMode - - -def test_cpu(): - torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend("cpu") - - t = torch.ones((10, 10), device="cpu") - u = 2 * torch.ones((10, 10), device="cpu") - - tt = TorchMLIRTensor(t) - print(tt) - uu = TorchMLIRTensor(u) - print(uu) - - for i in range(NUM_ITERS): - yy = tt + uu - print(type(yy)) - print(yy.elem.to_host()) - yy = tt * uu - print(type(yy)) - print(yy.elem.to_host()) - - -def test_gpu(): - source = """ - #include <iostream> - #include "cuda.h" - #include "cuda_runtime_api.h" - - using namespace std; - - void print_free_mem() { - int num_gpus; - size_t free, total; - cudaSetDevice(0); - int id; - cudaGetDevice(&id); - cudaMemGetInfo(&free, &total); - cout << "GPU " << id << " memory: used=" << (total-free)/(1<<20) << endl; - } - """ - gpu_stats = load_inline( - name="inline_extension", - cpp_sources=[source], - extra_include_paths=include_paths(cuda=True), - functions=["print_free_mem"], - ) - torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend("gpu") - - t = torch.ones((10, 10), device="cpu") - u = 2 * torch.ones((10, 10), device="cpu") - - tt = TorchMLIRTensor(t) - print(tt) - uu = TorchMLIRTensor(u) - print(uu) - - for i in range(NUM_ITERS): - yy = tt + uu - print(yy.elem.to_host()) - yy = tt * uu - print(yy.elem.to_host()) - gpu_stats.print_free_mem() - - -def test_python_mode_ref_backend(): - # hide this wherever you want? - _ = SharkEagerMode("refbackend") - - t = torch.ones((10, 10), device="cpu") - u = torch.ones((10, 10), device="cpu") - - print(t) - print(u) - - for i in range(NUM_ITERS): - print(i) - yy = t + u - print(yy.elem) - yy = t * u - print(yy.elem) - - -def test_python_mode_iree_cpu(): - # hide this wherever you want? 
- _ = SharkEagerMode("cpu") - - t = torch.ones((10, 10), device="cpu") - u = torch.ones((10, 10), device="cpu") - - print(t) - print(u) - - for i in range(NUM_ITERS): - yy = t + u - print(type(yy)) - print(yy.elem.to_host()) - yy = t * u - print(type(yy)) - print(yy.elem.to_host()) - - -def test_python_mode_iree_gpu(): - _ = SharkEagerMode("gpu") - - t = torch.ones((10, 10), device="cpu") - u = torch.ones((10, 10), device="cpu") - - print(t) - print(u) - - for i in range(NUM_ITERS): - yy = t + u - print(type(yy)) - print(yy.elem.to_host()) - yy = t * u - print(type(yy)) - print(yy.elem.to_host()) - - -if __name__ == "__main__": - NUM_ITERS = 10 - test_cpu() - if torch.cuda.is_available(): - test_gpu() - test_python_mode_ref_backend() - test_python_mode_iree_cpu() - test_python_mode_iree_gpu() diff --git a/shark/examples/shark_eager/squeezenet_lockstep.py b/shark/examples/shark_eager/squeezenet_lockstep.py deleted file mode 100644 index f5fcf42d33..0000000000 --- a/shark/examples/shark_eager/squeezenet_lockstep.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -import numpy as np - -model = torch.hub.load( - "pytorch/vision:v0.10.0", "squeezenet1_0", pretrained=True -) -model.eval() - -# from PIL import Image -# from torchvision import transforms -# import urllib -# -# url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg") -# try: urllib.URLopener().retrieve(url, filename) -# except: urllib.request.urlretrieve(url, filename) -# -# -# input_image = Image.open(filename) -# preprocess = transforms.Compose([ -# transforms.Resize(256), -# transforms.CenterCrop(224), -# transforms.ToTensor(), -# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), -# ]) -# input_tensor = preprocess(input_image) -# input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model -# print(input_batch.shape) # size = [1, 3, 224, 224] - -# The above is code for generating sample inputs from an image. We can just use -# random values for accuracy testing though -input_batch = torch.randn(1, 3, 224, 224) - - -# Focus on CPU for now -if False and torch.cuda.is_available(): - input_batch = input_batch.to("cuda") - model.to("cuda") - -with torch.no_grad(): - output = model(input_batch) -# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes -golden_confidences = output[0] -# The output has unnormalized scores. To get probabilities, you can run a softmax on it. 
-golden_probabilities = torch.nn.functional.softmax( - golden_confidences, dim=0 -).numpy() - -golden_confidences = golden_confidences.numpy() - -from shark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor - -input_detached_clone = input_batch.clone() -eager_input_batch = TorchMLIRLockstepTensor(input_detached_clone) - -print("getting torch-mlir result") - -output = model(eager_input_batch) - -static_output = output.elem -confidences = static_output[0] -probabilities = torch.nn.functional.softmax( - torch.from_numpy(confidences), dim=0 -).numpy() - -print("The obtained result via shark is: ", confidences) -print("The golden result is:", golden_confidences) - -np.testing.assert_allclose( - golden_confidences, confidences, rtol=1e-02, atol=1e-03 -) -np.testing.assert_allclose( - golden_probabilities, probabilities, rtol=1e-02, atol=1e-03 -) diff --git a/shark/examples/shark_inference/CLIPModel_tf.py b/shark/examples/shark_inference/CLIPModel_tf.py deleted file mode 100644 index 78d909498e..0000000000 --- a/shark/examples/shark_inference/CLIPModel_tf.py +++ /dev/null @@ -1,65 +0,0 @@ -from PIL import Image -import requests - -from transformers import CLIPProcessor, TFCLIPModel -import tensorflow as tf -from shark.shark_inference import SharkInference - -# Create a set of inputs -clip_vit_inputs = [ - tf.TensorSpec(shape=[2, 7], dtype=tf.int32), - tf.TensorSpec(shape=[2, 7], dtype=tf.int32), - tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32), -] - - -class CLIPModule(tf.Module): - def __init__(self): - super(CLIPModule, self).__init__() - self.m = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32") - - self.m.predict = lambda x, y, z: self.m( - input_ids=x, attention_mask=y, pixel_values=z - ) - - @tf.function(input_signature=clip_vit_inputs, jit_compile=True) - def forward(self, input_ids, attention_mask, pixel_values): - return self.m.predict( - input_ids, attention_mask, pixel_values - ).logits_per_image - - -if __name__ == "__main__": - # Prepping Data - processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") - - url = "http://images.cocodataset.org/val2017/000000039769.jpg" - image = Image.open(requests.get(url, stream=True).raw) - - inputs = processor( - text=["a photo of a cat", "a photo of a dog"], - images=image, - return_tensors="tf", - padding=True, - ) - - shark_module = SharkInference( - CLIPModule(), - ( - inputs["input_ids"], - inputs["attention_mask"], - inputs["pixel_values"], - ), - ) - shark_module.set_frontend("tensorflow") - shark_module.compile() - - print( - shark_module.forward( - ( - inputs["input_ids"], - inputs["attention_mask"], - inputs["pixel_values"], - ) - ) - ) diff --git a/shark/examples/shark_inference/ESRGAN/README.md b/shark/examples/shark_inference/ESRGAN/README.md deleted file mode 100644 index 60b90646cb..0000000000 --- a/shark/examples/shark_inference/ESRGAN/README.md +++ /dev/null @@ -1,15 +0,0 @@ -## Running ESRGAN - -``` -1. pip install numpy opencv-python -2. mkdir InputImages - (this is where all the input images will reside in) -3. mkdir OutputImages - (this is where the model will generate all the images) -4. mkdir models - (save the .pth checkpoint file here) -5. python esrgan.py -``` - -- Download [RRDB_ESRGAN_x4.pth](https://drive.google.com/drive/u/0/folders/17VYV_SoZZesU6mbxz2dMAIccSSlqLecY) and place it in the `models` directory as mentioned above in step 4. 
-- Credits : [ESRGAN](https://github.com/xinntao/ESRGAN) diff --git a/shark/examples/shark_inference/ESRGAN/esrgan.py b/shark/examples/shark_inference/ESRGAN/esrgan.py deleted file mode 100644 index 426f4e93c9..0000000000 --- a/shark/examples/shark_inference/ESRGAN/esrgan.py +++ /dev/null @@ -1,239 +0,0 @@ -from ast import arg -import os.path as osp -import glob -import cv2 -import numpy as np -import torch - -from torch.fx.experimental.proxy_tensor import make_fx -from torch._decomp import get_decompositions -from shark.shark_inference import SharkInference -import torch_mlir -import tempfile -import functools -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def make_layer(block, n_layers): - layers = [] - for _ in range(n_layers): - layers.append(block()) - return nn.Sequential(*layers) - - -class ResidualDenseBlock_5C(nn.Module): - def __init__(self, nf=64, gc=32, bias=True): - super(ResidualDenseBlock_5C, self).__init__() - # gc: growth channel, i.e. intermediate channels - self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) - self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) - self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) - self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) - self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - # initialization - # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) - - def forward(self, x): - x1 = self.lrelu(self.conv1(x)) - x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) - x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) - x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5 * 0.2 + x - - -class RRDB(nn.Module): - """Residual in Residual Dense Block""" - - def __init__(self, nf, gc=32): - super(RRDB, self).__init__() - self.RDB1 = ResidualDenseBlock_5C(nf, gc) - self.RDB2 = ResidualDenseBlock_5C(nf, gc) - self.RDB3 = ResidualDenseBlock_5C(nf, gc) - - def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - return out * 0.2 + x - - -class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32): - super(RRDBNet, self).__init__() - RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) - - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.RRDB_trunk = make_layer(RRDB_block_f, nb) - self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - #### upsampling - self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - def forward(self, x): - fea = self.conv_first(x) - trunk = self.trunk_conv(self.RRDB_trunk(fea)) - fea = fea + trunk - - fea = self.lrelu( - self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest")) - ) - fea = self.lrelu( - self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest")) - ) - out = self.conv_last(self.lrelu(self.HRconv(fea))) - - return out - - -############### Parsing args ##################### -import argparse - -p = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter -) - -p.add_argument("--device", type=str, default="cpu", help="the device to use") -p.add_argument( - "--mlir_loc", - type=str, - default=None, - help="location of the 
model's mlir file", -) -args = p.parse_args() -################################################### - - -def inference(input_m): - return model(input_m) - - -def load_mlir(mlir_loc): - import os - - if mlir_loc == None: - return None - print(f"Trying to load the model from {mlir_loc}.") - with open(os.path.join(mlir_loc)) as f: - mlir_module = f.read() - return mlir_module - - -def compile_through_fx(model, inputs, mlir_loc=None): - module = load_mlir(mlir_loc) - if module == None: - fx_g = make_fx( - model, - decomposition_table=get_decompositions( - [ - torch.ops.aten.embedding_dense_backward, - torch.ops.aten.native_layer_norm_backward, - torch.ops.aten.slice_backward, - torch.ops.aten.select_backward, - torch.ops.aten.norm.ScalarOpt_dim, - torch.ops.aten.native_group_norm, - torch.ops.aten.upsample_bilinear2d.vec, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes, - ] - ), - )(inputs) - - fx_g.graph.set_codegen(torch.fx.graph.CodeGen()) - fx_g.recompile() - - def strip_overloads(gm): - """ - Modifies the target of graph nodes in :attr:`gm` to strip overloads. - Args: - gm(fx.GraphModule): The input Fx graph module to be modified - """ - for node in gm.graph.nodes: - if isinstance(node.target, torch._ops.OpOverload): - node.target = node.target.overloadpacket - gm.recompile() - - strip_overloads(fx_g) - - ts_g = torch.jit.script(fx_g) - - print("Torchscript graph generated successfully") - module = torch_mlir.compile( - ts_g, - inputs, - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - - mlir_model = str(module) - func_name = "forward" - shark_module = SharkInference( - mlir_model, device=args.device, mlir_dialect="linalg" - ) - shark_module.compile() - - return shark_module - - -model_path = "models/RRDB_ESRGAN_x4.pth" # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth -# device = torch.device('cuda') # if you want to run on CPU, change 'cuda' -> cpu -device = torch.device("cpu") - -test_img_folder = "InputImages/*" - -model = RRDBNet(3, 3, 64, 23, gc=32) -model.load_state_dict(torch.load(model_path), strict=True) -model.eval() -model = model.to(device) - -print("Model path {:s}. 
\nTesting...".format(model_path)) - -if __name__ == "__main__": - idx = 0 - for path in glob.glob(test_img_folder): - idx += 1 - base = osp.splitext(osp.basename(path))[0] - print(idx, base) - # read images - img = cv2.imread(path, cv2.IMREAD_COLOR) - img = img * 1.0 / 255 - img = torch.from_numpy( - np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1)) - ).float() - img_LR = img.unsqueeze(0) - img_LR = img_LR.to(device) - - with torch.no_grad(): - shark_module = compile_through_fx(inference, img_LR) - shark_output = shark_module.forward((img_LR,)) - shark_output = torch.from_numpy(shark_output) - shark_output = ( - shark_output.data.squeeze().float().cpu().clamp_(0, 1).numpy() - ) - esrgan_output = ( - model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy() - ) - # SHARK OUTPUT - shark_output = np.transpose(shark_output[[2, 1, 0], :, :], (1, 2, 0)) - shark_output = (shark_output * 255.0).round() - cv2.imwrite( - "OutputImages/{:s}_rlt_shark_output.png".format(base), shark_output - ) - print("Generated SHARK's output") - # ESRGAN OUTPUT - esrgan_output = np.transpose(esrgan_output[[2, 1, 0], :, :], (1, 2, 0)) - esrgan_output = (esrgan_output * 255.0).round() - cv2.imwrite( - "OutputImages/{:s}_rlt_esrgan_output.png".format(base), - esrgan_output, - ) - print("Generated ESRGAN's output") diff --git a/shark/examples/shark_inference/albert_maskfill_pt.py b/shark/examples/shark_inference/albert_maskfill_pt.py deleted file mode 100644 index 2c5d11c362..0000000000 --- a/shark/examples/shark_inference/albert_maskfill_pt.py +++ /dev/null @@ -1,86 +0,0 @@ -from transformers import AutoModelForMaskedLM, AutoTokenizer -import torch -from shark.shark_inference import SharkInference -from shark.shark_importer import SharkImporter -from iree.compiler import compile_str -from iree import runtime as ireert -import os -import numpy as np - -MAX_SEQUENCE_LENGTH = 512 -BATCH_SIZE = 1 - - -class AlbertModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.model = AutoModelForMaskedLM.from_pretrained("albert-base-v2") - self.model.eval() - - def forward(self, input_ids, attention_mask): - return self.model( - input_ids=input_ids, attention_mask=attention_mask - ).logits - - -if __name__ == "__main__": - # Prepping Data - tokenizer = AutoTokenizer.from_pretrained("albert-base-v2") - text = "This [MASK] is very tasty." 
- encoded_inputs = tokenizer( - text, - padding="max_length", - truncation=True, - max_length=MAX_SEQUENCE_LENGTH, - return_tensors="pt", - ) - inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"]) - mlir_importer = SharkImporter( - AlbertModule(), - inputs, - frontend="torch", - ) - minilm_mlir, func_name = mlir_importer.import_mlir( - is_dynamic=False, tracing_required=True - ) - shark_module = SharkInference(minilm_mlir) - shark_module.compile() - token_logits = torch.tensor(shark_module.forward(inputs)) - mask_id = torch.where( - encoded_inputs["input_ids"] == tokenizer.mask_token_id - )[1] - mask_token_logits = token_logits[0, mask_id, :] - top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist() - for token in top_5_tokens: - print( - f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'" - ) - while True: - try: - new_text = input("Give me a sentence with [MASK] to fill: ") - encoded_inputs = tokenizer( - new_text, - padding="max_length", - truncation=True, - max_length=MAX_SEQUENCE_LENGTH, - return_tensors="pt", - ) - inputs = ( - encoded_inputs["input_ids"], - encoded_inputs["attention_mask"], - ) - token_logits = torch.tensor(shark_module.forward(inputs)) - mask_id = torch.where( - encoded_inputs["input_ids"] == tokenizer.mask_token_id - )[1] - mask_token_logits = token_logits[0, mask_id, :] - top_5_tokens = ( - torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist() - ) - for token in top_5_tokens: - print( - f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'" - ) - except KeyboardInterrupt: - print("Exiting program.") - break diff --git a/shark/examples/shark_inference/albert_maskfill_tf.py b/shark/examples/shark_inference/albert_maskfill_tf.py deleted file mode 100644 index a7927ecdfa..0000000000 --- a/shark/examples/shark_inference/albert_maskfill_tf.py +++ /dev/null @@ -1,100 +0,0 @@ -from PIL import Image -import requests - -from transformers import TFAutoModelForMaskedLM, AutoTokenizer -import tensorflow as tf -from shark.shark_inference import SharkInference -from shark.shark_importer import SharkImporter -from iree.compiler import tf as tfc -from iree.compiler import compile_str -from iree import runtime as ireert -import os -import numpy as np -import sys - -MAX_SEQUENCE_LENGTH = 512 -BATCH_SIZE = 1 - -# Create a set of inputs -t5_inputs = [ - tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32), - tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32), -] - - -class AlbertModule(tf.Module): - def __init__(self): - super(AlbertModule, self).__init__() - self.m = TFAutoModelForMaskedLM.from_pretrained("albert-base-v2") - self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y) - - @tf.function(input_signature=t5_inputs, jit_compile=True) - def forward(self, input_ids, attention_mask): - return self.m.predict(input_ids, attention_mask) - - -if __name__ == "__main__": - # Prepping Data - tokenizer = AutoTokenizer.from_pretrained("albert-base-v2") - # text = "This is a great [MASK]." - text = "This [MASK] is very tasty." 
- encoded_inputs = tokenizer( - text, - padding="max_length", - truncation=True, - max_length=MAX_SEQUENCE_LENGTH, - return_tensors="tf", - ) - inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"]) - mlir_importer = SharkImporter( - AlbertModule(), - inputs, - frontend="tf", - ) - minilm_mlir, func_name = mlir_importer.import_mlir( - is_dynamic=False, tracing_required=False - ) - shark_module = SharkInference(minilm_mlir, mlir_dialect="mhlo") - shark_module.compile() - output_idx = 0 - data_idx = 1 - token_logits = shark_module.forward(inputs)[output_idx][data_idx] - mask_id = np.where( - tf.squeeze(encoded_inputs["input_ids"]) == tokenizer.mask_token_id - ) - mask_token_logits = token_logits[0, mask_id, :] - top_5_tokens = np.flip(np.argsort(mask_token_logits)).squeeze()[0:5] - for token in top_5_tokens: - print( - f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'" - ) - while True: - try: - new_text = input("Give me a sentence with [MASK] to fill: ") - encoded_inputs = tokenizer( - new_text, - padding="max_length", - truncation=True, - max_length=MAX_SEQUENCE_LENGTH, - return_tensors="tf", - ) - inputs = ( - encoded_inputs["input_ids"], - encoded_inputs["attention_mask"], - ) - token_logits = shark_module.forward(inputs)[output_idx][data_idx] - mask_id = np.where( - tf.squeeze(encoded_inputs["input_ids"]) - == tokenizer.mask_token_id - ) - mask_token_logits = token_logits[0, mask_id, :] - top_5_tokens = np.flip(np.argsort(mask_token_logits)).squeeze()[ - 0:5 - ] - for token in top_5_tokens: - print( - f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'" - ) - except KeyboardInterrupt: - print("Exiting program.") - sys.exit() diff --git a/shark/examples/shark_inference/bloom_tank.py b/shark/examples/shark_inference/bloom_tank.py deleted file mode 100644 index 25f67f6ceb..0000000000 --- a/shark/examples/shark_inference/bloom_tank.py +++ /dev/null @@ -1,14 +0,0 @@ -from shark.shark_inference import SharkInference -from shark.shark_downloader import download_model - -mlir_model, func_name, inputs, golden_out = download_model( - "bloom", frontend="torch" -) - -shark_module = SharkInference( - mlir_model, device="cpu", mlir_dialect="tm_tensor" -) -shark_module.compile() -result = shark_module.forward(inputs) -print("The obtained result via shark is: ", result) -print("The golden result is:", golden_out) diff --git a/shark/examples/shark_inference/gpt2_tf.py b/shark/examples/shark_inference/gpt2_tf.py deleted file mode 100644 index 98b402fc88..0000000000 --- a/shark/examples/shark_inference/gpt2_tf.py +++ /dev/null @@ -1,40 +0,0 @@ -from PIL import Image -import requests - -from transformers import GPT2Tokenizer, TFGPT2Model -import tensorflow as tf -from shark.shark_inference import SharkInference - -# Create a set of inputs -gpt2_inputs = [ - tf.TensorSpec(shape=[1, 8], dtype=tf.int32), - tf.TensorSpec(shape=[1, 8], dtype=tf.int32), -] - - -class GPT2Module(tf.Module): - def __init__(self): - super(GPT2Module, self).__init__() - self.m = TFGPT2Model.from_pretrained("distilgpt2") - - self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y) - - @tf.function(input_signature=gpt2_inputs, jit_compile=True) - def forward(self, input_ids, attention_mask): - return self.m.predict(input_ids, attention_mask) - - -if __name__ == "__main__": - # Prepping Data - tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2") - text = "I love the distilled version of models." 
- - inputs = tokenizer(text, return_tensors="tf") - shark_module = SharkInference( - GPT2Module(), (inputs["input_ids"], inputs["attention_mask"]) - ) - shark_module.set_frontend("tensorflow") - shark_module.compile() - print( - shark_module.forward((inputs["input_ids"], inputs["attention_mask"])) - ) diff --git a/shark/examples/shark_inference/llama/README.md b/shark/examples/shark_inference/llama/README.md deleted file mode 100644 index e6ca34895b..0000000000 --- a/shark/examples/shark_inference/llama/README.md +++ /dev/null @@ -1,18 +0,0 @@ -# SHARK LLaMA - -## TORCH-MLIR Version - -``` -https://github.com/nod-ai/torch-mlir.git -``` -Then check out the `complex` branch and `git submodule update --init` and then build with `.\build_tools\python_deploy\build_windows.ps1` - -### Setup & Run -``` -git clone https://github.com/nod-ai/llama.git -``` -Then in this repository -``` -pip install -e . -python llama/shark_model.py -``` diff --git a/shark/examples/shark_inference/mega_test.py b/shark/examples/shark_inference/mega_test.py deleted file mode 100644 index a4e6f6b406..0000000000 --- a/shark/examples/shark_inference/mega_test.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch -import torch_mlir -from shark.shark_inference import SharkInference -from shark.shark_compile import shark_compile_through_fx -from MEGABYTE_pytorch import MEGABYTE - -import os - - -class MegaModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.model = MEGABYTE( - num_tokens=16000, # number of tokens - dim=( - 512, - 256, - ), # transformer model dimension (512 for coarsest, 256 for fine in this example) - max_seq_len=( - 1024, - 4, - ), # sequence length for global and then local. this can be more than 2 - depth=( - 6, - 4, - ), # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's - dim_head=64, # dimension per head - heads=8, # number of attention heads - flash_attn=True, # use flash attention - ) - - def forward(self, input): - return self.model(input) - - -megaModel = MegaModel() -inputs = [torch.randint(0, 16000, (1, 1024, 4))] - -# CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :- -# 1. aten.alias -shark_module, _ = shark_compile_through_fx( - model=megaModel, - inputs=inputs, - extended_model_name="mega_shark", - is_f16=False, - f16_input_mask=None, - save_dir=os.getcwd(), - debug=False, - generate_or_load_vmfb=True, - extra_args=[], - device="cuda", - mlir_dialect="tm_tensor", -) -# logits = model(x) - - -def print_output_info(output, msg): - print("\n", msg) - print("\n\t", output.shape) - - -ans = shark_module("forward", inputs) -print_output_info(torch.from_numpy(ans), "SHARK's output") - -ans = megaModel.forward(*inputs) -print_output_info(ans, "ORIGINAL Model's output") - -# and sample from the logits accordingly -# or you can use the generate function - -# NEED TO LOOK AT THIS LATER IF REQUIRED IN SHARK. 
-# sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4) diff --git a/shark/examples/shark_inference/mhlo_example.py b/shark/examples/shark_inference/mhlo_example.py deleted file mode 100644 index b271403231..0000000000 --- a/shark/examples/shark_inference/mhlo_example.py +++ /dev/null @@ -1,31 +0,0 @@ -from shark.shark_inference import SharkInference -import numpy as np - -mhlo_ir = r"""builtin.module { - func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> { - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32> - %1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32> - return %1 : tensor<4x4xf32> - } -}""" - -arg0 = np.ones((1, 4)).astype(np.float32) -arg1 = np.ones((4, 1)).astype(np.float32) - -print("Running shark on cpu backend") -shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo") - -# Generate the random inputs and feed into the graph. -x = shark_module.generate_random_inputs() -shark_module.compile() -print(shark_module.forward(x)) - -print("Running shark on cuda backend") -shark_module = SharkInference(mhlo_ir, device="cuda", mlir_dialect="mhlo") -shark_module.compile() -print(shark_module.forward(x)) - -print("Running shark on vulkan backend") -shark_module = SharkInference(mhlo_ir, device="vulkan", mlir_dialect="mhlo") -shark_module.compile() -print(shark_module.forward(x)) diff --git a/shark/examples/shark_inference/minilm_benchmark.py b/shark/examples/shark_inference/minilm_benchmark.py deleted file mode 100644 index 3263f84d18..0000000000 --- a/shark/examples/shark_inference/minilm_benchmark.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch -from transformers import AutoTokenizer, AutoModelForSequenceClassification -from shark.shark_inference import SharkInference - -torch.manual_seed(0) -tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased") - - -class MiniLMSequenceClassification(torch.nn.Module): - def __init__(self): - super().__init__() - self.model = AutoModelForSequenceClassification.from_pretrained( - "microsoft/MiniLM-L12-H384-uncased", # The pretrained model. - num_labels=2, # The number of output labels--2 for binary classification. - output_attentions=False, # Whether the model returns attentions weights. - output_hidden_states=False, # Whether the model returns all hidden-states. 
- torchscript=True, - ) - - def forward(self, tokens): - return self.model.forward(tokens)[0] - - -test_input = torch.randint(2, (1, 128)) - -shark_module = SharkInference( - MiniLMSequenceClassification(), - (test_input,), - jit_trace=True, - benchmark_mode=True, -) - -shark_module.compile() -shark_module.forward((test_input,)) -shark_module.benchmark_all((test_input,)) diff --git a/shark/examples/shark_inference/minilm_benchmark_tf.py b/shark/examples/shark_inference/minilm_benchmark_tf.py deleted file mode 100644 index 1c2858e817..0000000000 --- a/shark/examples/shark_inference/minilm_benchmark_tf.py +++ /dev/null @@ -1,61 +0,0 @@ -import tensorflow as tf -from transformers import BertModel, BertTokenizer, TFBertModel -from shark.shark_inference import SharkInference - -MAX_SEQUENCE_LENGTH = 512 -BATCH_SIZE = 1 - -# Create a set of 2-dimensional inputs -bert_input = [ - tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32), - tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32), - tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32), -] - - -class BertModule(tf.Module): - def __init__(self): - super(BertModule, self).__init__() - # Create a BERT trainer with the created network. - self.m = TFBertModel.from_pretrained( - "microsoft/MiniLM-L12-H384-uncased", from_pt=True - ) - - # Invoke the trainer model on the inputs. This causes the layer to be built. - self.m.predict = lambda x, y, z: self.m.call( - input_ids=x, attention_mask=y, token_type_ids=z, training=False - ) - - @tf.function(input_signature=bert_input, jit_compile=True) - def forward(self, input_ids, attention_mask, token_type_ids): - return self.m.predict(input_ids, attention_mask, token_type_ids) - - -if __name__ == "__main__": - # Prepping Data - tokenizer = BertTokenizer.from_pretrained( - "microsoft/MiniLM-L12-H384-uncased" - ) - text = "Replace me by any text you'd like." 
- encoded_input = tokenizer( - text, - padding="max_length", - truncation=True, - max_length=MAX_SEQUENCE_LENGTH, - ) - for key in encoded_input: - encoded_input[key] = tf.expand_dims( - tf.convert_to_tensor(encoded_input[key]), 0 - ) - - test_input = ( - encoded_input["input_ids"], - encoded_input["attention_mask"], - encoded_input["token_type_ids"], - ) - shark_module = SharkInference( - BertModule(), test_input, benchmark_mode=True - ) - shark_module.set_frontend("tensorflow") - shark_module.compile() - shark_module.benchmark_all(test_input) diff --git a/shark/examples/shark_inference/minilm_jax.py b/shark/examples/shark_inference/minilm_jax.py deleted file mode 100644 index 6207c6f270..0000000000 --- a/shark/examples/shark_inference/minilm_jax.py +++ /dev/null @@ -1,73 +0,0 @@ -from transformers import AutoTokenizer, FlaxAutoModel -import torch -import jax -from typing import Union, Dict, List, Any -import numpy as np -from shark.shark_inference import SharkInference -import io - -NumpyTree = Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]] - - -def convert_torch_tensor_tree_to_numpy( - tree: Union[torch.tensor, Dict[str, torch.tensor], List[torch.tensor]] -) -> NumpyTree: - return jax.tree_util.tree_map( - lambda torch_tensor: torch_tensor.cpu().detach().numpy(), tree - ) - - -def convert_int64_to_int32(tree: NumpyTree) -> NumpyTree: - return jax.tree_util.tree_map( - lambda tensor: np.array(tensor, dtype=np.int32) - if tensor.dtype == np.int64 - else tensor, - tree, - ) - - -def get_sample_input(): - tokenizer = AutoTokenizer.from_pretrained( - "microsoft/MiniLM-L12-H384-uncased" - ) - inputs_torch = tokenizer("Hello, World!", return_tensors="pt") - return convert_int64_to_int32( - convert_torch_tensor_tree_to_numpy(inputs_torch.data) - ) - - -def get_jax_model(): - return FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased") - - -def export_jax_to_mlir(jax_model: Any, sample_input: NumpyTree): - model_mlir = jax.jit(jax_model).lower(**sample_input).compiler_ir() - byte_stream = io.BytesIO() - model_mlir.operation.write_bytecode(file=byte_stream) - return byte_stream.getvalue() - - -def assert_array_list_allclose(x, y, *args, **kwargs): - assert len(x) == len(y) - for a, b in zip(x, y): - np.testing.assert_allclose( - np.asarray(a), np.asarray(b), *args, **kwargs - ) - - -sample_input = get_sample_input() -jax_model = get_jax_model() -mlir = export_jax_to_mlir(jax_model, sample_input) - -# Compile and load module. -shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo") -shark_inference.compile() - -# Run main function. -result = shark_inference("main", jax.tree_util.tree_flatten(sample_input)[0]) - -# Run JAX model. -reference_result = jax.tree_util.tree_flatten(jax_model(**sample_input))[0] - -# Verify result. 
-assert_array_list_allclose(result, reference_result, atol=1e-5) diff --git a/shark/examples/shark_inference/minilm_jax_requirements.txt b/shark/examples/shark_inference/minilm_jax_requirements.txt deleted file mode 100644 index 1f92025a49..0000000000 --- a/shark/examples/shark_inference/minilm_jax_requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -flax -jax[cpu] -nodai-SHARK -orbax -transformers -torch diff --git a/shark/examples/shark_inference/minilm_jit.py b/shark/examples/shark_inference/minilm_jit.py deleted file mode 100644 index df9aa98b7f..0000000000 --- a/shark/examples/shark_inference/minilm_jit.py +++ /dev/null @@ -1,23 +0,0 @@ -from shark.shark_inference import SharkInference -from shark.shark_downloader import download_model - - -mlir_model, func_name, inputs, golden_out = download_model( - "microsoft/MiniLM-L12-H384-uncased", - frontend="torch", -) - - -shark_module = SharkInference(mlir_model, device="cpu", mlir_dialect="linalg") -shark_module.compile() -result = shark_module.forward(inputs) -print("The obtained result via shark is: ", result) -print("The golden result is:", golden_out) - - -# Let's generate random inputs, currently supported -# for static models. -rand_inputs = shark_module.generate_random_inputs() -rand_results = shark_module.forward(rand_inputs) - -print("Running shark_module with random_inputs is: ", rand_results) diff --git a/shark/examples/shark_inference/minilm_tf.py b/shark/examples/shark_inference/minilm_tf.py deleted file mode 100644 index 3ac0d7bff8..0000000000 --- a/shark/examples/shark_inference/minilm_tf.py +++ /dev/null @@ -1,70 +0,0 @@ -import tensorflow as tf -from transformers import BertModel, BertTokenizer, TFBertModel -from shark.shark_inference import SharkInference - -MAX_SEQUENCE_LENGTH = 512 -BATCH_SIZE = 1 - -# Create a set of 2-dimensional inputs -bert_input = [ - tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32), - tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32), - tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32), -] - - -class BertModule(tf.Module): - def __init__(self): - super(BertModule, self).__init__() - # Create a BERT trainer with the created network. - self.m = TFBertModel.from_pretrained( - "microsoft/MiniLM-L12-H384-uncased", from_pt=True - ) - - # Invoke the trainer model on the inputs. This causes the layer to be built. - self.m.predict = lambda x, y, z: self.m.call( - input_ids=x, attention_mask=y, token_type_ids=z, training=False - ) - - @tf.function(input_signature=bert_input, jit_compile=True) - def forward(self, input_ids, attention_mask, token_type_ids): - return self.m.predict(input_ids, attention_mask, token_type_ids) - - -if __name__ == "__main__": - # Prepping Data - tokenizer = BertTokenizer.from_pretrained( - "microsoft/MiniLM-L12-H384-uncased" - ) - text = "Replace me by any text you'd like." 
- encoded_input = tokenizer( - text, - padding="max_length", - truncation=True, - max_length=MAX_SEQUENCE_LENGTH, - ) - for key in encoded_input: - encoded_input[key] = tf.expand_dims( - tf.convert_to_tensor(encoded_input[key]), 0 - ) - - shark_module = SharkInference( - BertModule(), - ( - encoded_input["input_ids"], - encoded_input["attention_mask"], - encoded_input["token_type_ids"], - ), - ) - shark_module.set_frontend("tensorflow") - shark_module.compile() - - print( - shark_module.forward( - ( - encoded_input["input_ids"], - encoded_input["attention_mask"], - encoded_input["token_type_ids"], - ) - ) - ) diff --git a/shark/examples/shark_inference/minilm_tf_gpu_config.json b/shark/examples/shark_inference/minilm_tf_gpu_config.json deleted file mode 100644 index 2cb764c709..0000000000 --- a/shark/examples/shark_inference/minilm_tf_gpu_config.json +++ /dev/null @@ -1 +0,0 @@ -{"options": [{"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 64, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 32, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 64, 32], "work_group_sizes": [128, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 64, 64, 32], "work_group_sizes": [128, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4, "split_k": 8}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 64, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 32, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 64, 32], "work_group_sizes": [128, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 64, 64, 32], "work_group_sizes": [128, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4, "split_k": 8}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 64, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 32, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", 
"pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 64, 32], "work_group_sizes": [128, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 64, 64, 32], "work_group_sizes": [128, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4, "split_k": 8}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 64, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 32, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 64, 32], "work_group_sizes": [128, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 64, 64, 32], "work_group_sizes": [128, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4, "split_k": 8}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 64, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 32, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 64, 32], "work_group_sizes": [128, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 64, 64, 32], "work_group_sizes": [128, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4, "split_k": 8}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 64, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 32, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 64, 32], "work_group_sizes": [128, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 64, 64, 32], "work_group_sizes": [128, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4, "split_k": 8}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": 
"GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 64, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 32, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 64, 32], "work_group_sizes": [128, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 64, 64, 32], "work_group_sizes": [128, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4, "split_k": 8}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 64, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 32, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 64, 32], "work_group_sizes": [128, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 64, 64, 32], "work_group_sizes": [128, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4, "split_k": 8}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 64, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 32, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 64, 32], "work_group_sizes": [128, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 64, 64, 32], "work_group_sizes": [128, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4, "split_k": 8}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 64, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 32, 16], "work_group_sizes": [64, 1, 1], "pipeline": 
"GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 64, 32], "work_group_sizes": [128, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 64, 64, 32], "work_group_sizes": [128, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4, "split_k": 8}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 64, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 32, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 64, 32], "work_group_sizes": [128, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 64, 64, 32], "work_group_sizes": [128, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4, "split_k": 8}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 64, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 32, 32, 16], "work_group_sizes": [64, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 32, 16], "work_group_sizes": [64, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [32, 64, 32], "work_group_sizes": [128, 1, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4}, {"work_group_tile_sizes": [1, 64, 64, 32], "work_group_sizes": [128, 2, 1], "pipeline": "GPU_TENSORCORE", "pipeline_depth": 4, "split_k": 8}, {"work_group_tile_sizes": [1, 32, 128], "work_group_sizes": [32, 1, 1], "pipeline": "GPU"}]} \ No newline at end of file diff --git a/shark/examples/shark_inference/resnest.py b/shark/examples/shark_inference/resnest.py deleted file mode 100644 index cfb81f3f5f..0000000000 --- a/shark/examples/shark_inference/resnest.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch -import torchvision.models as models -from shark.shark_inference import SharkInference -from shark.shark_importer import SharkImporter - -torch.hub.list("zhanghang1989/ResNeSt", force_reload=True) - - -class ResnestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.model = torch.hub.load( - "zhanghang1989/ResNeSt", "resnest50", pretrained=True - ) - self.model.eval() - - def forward(self, input): - return self.model.forward(input) - - -input = torch.randn(1, 3, 224, 224) - - -mlir_importer = SharkImporter( - ResnestModule(), - (input,), - frontend="torch", -) - -(vision_mlir, func_name), inputs, golden_out = mlir_importer.import_debug( - tracing_required=True -) - -print(golden_out) - 
-shark_module = SharkInference(vision_mlir, mlir_dialect="linalg") -shark_module.compile() -result = shark_module.forward((input,)) -print("Obtained result", result) diff --git a/shark/examples/shark_inference/resnet50_fp16.py b/shark/examples/shark_inference/resnet50_fp16.py deleted file mode 100644 index 4fe6aa4079..0000000000 --- a/shark/examples/shark_inference/resnet50_fp16.py +++ /dev/null @@ -1,74 +0,0 @@ -from shark.shark_inference import SharkInference -from shark.parser import shark_args - -import torch -import numpy as np -import sys -import torchvision.models as models -import torch_mlir - -torch.manual_seed(0) - - -class VisionModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.model = models.resnet50(pretrained=True) - self.train(False) - - def forward(self, input): - return self.model.forward(input) - - -model = VisionModule() -test_input = torch.randn(1, 3, 224, 224) -actual_out = model(test_input) - -test_input_fp16 = test_input.to(device=torch.device("cuda"), dtype=torch.half) -model_fp16 = model.half() -model_fp16.eval() -model_fp16.to("cuda") -actual_out_fp16 = model_fp16(test_input_fp16) - -ts_g = torch.jit.trace(model_fp16, [test_input_fp16]) - -module = torch_mlir.compile( - ts_g, - (test_input_fp16), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=True, - verbose=False, -) - -# from contextlib import redirect_stdout - -# with open('resnet50_fp16_linalg_ir.mlir', 'w') as f: -# with redirect_stdout(f): -# print(module.operation.get_asm()) - -mlir_model = module -func_name = "forward" - -shark_module = SharkInference(mlir_model, device="cuda", mlir_dialect="linalg") -shark_module.compile() - - -def shark_result(x): - x_ny = x.cpu().detach().numpy() - inputs = (x_ny,) - result = shark_module.forward(inputs) - return torch.from_numpy(result) - - -observed_out = shark_result(test_input_fp16) - -print("Golden result:", actual_out_fp16) -print("SHARK result:", observed_out) - -actual_out_fp16 = actual_out_fp16.to(device=torch.device("cpu")) - -print( - torch.testing.assert_allclose( - actual_out_fp16, observed_out, rtol=1e-2, atol=1e-2 - ) -) diff --git a/shark/examples/shark_inference/resnet50_script.py b/shark/examples/shark_inference/resnet50_script.py deleted file mode 100644 index e597d42932..0000000000 --- a/shark/examples/shark_inference/resnet50_script.py +++ /dev/null @@ -1,85 +0,0 @@ -from PIL import Image -import requests -import torch -import torchvision.models as models -from torchvision import transforms -import sys -from shark.shark_inference import SharkInference -from shark.shark_downloader import download_model - - -################################## Preprocessing inputs and model ############ -def load_and_preprocess_image(url: str): - headers = { - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36" - } - img = Image.open( - requests.get(url, headers=headers, stream=True).raw - ).convert("RGB") - # preprocessing pipeline - preprocess = transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ), - ] - ) - img_preprocessed = preprocess(img) - return torch.unsqueeze(img_preprocessed, 0) - - -def load_labels(): - classes_text = requests.get( - "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt", - stream=True, - ).text - labels = [line.strip() for line in classes_text.splitlines()] - 
return labels - - -def top3_possibilities(res): - _, indexes = torch.sort(res, descending=True) - percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100 - top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]] - return top3 - - -class Resnet50Module(torch.nn.Module): - def __init__(self): - super().__init__() - self.resnet = models.resnet50(pretrained=True) - self.train(False) - - def forward(self, img): - return self.resnet.forward(img) - - -image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg" -print("load image from " + image_url, file=sys.stderr) -img = load_and_preprocess_image(image_url) -labels = load_labels() - -############################################################################## - - -## Can pass any img or input to the forward module. -mlir_model, func_name, inputs, golden_out = download_model( - "resnet50", frontend="torch" -) - -shark_module = SharkInference(mlir_model, mlir_dialect="linalg") -shark_module.compile() -path = shark_module.save_module() -shark_module.load_module(path) -result = shark_module("forward", (img.detach().numpy(),)) - -print("The top 3 results obtained via shark_runner is:") -print(top3_possibilities(torch.from_numpy(result))) - -print() - -print("The top 3 results obtained via torch is:") -print(top3_possibilities(Resnet50Module()(img))) diff --git a/shark/examples/shark_inference/sharded_bloom.py b/shark/examples/shark_inference/sharded_bloom.py deleted file mode 100644 index 3e8df14f2d..0000000000 --- a/shark/examples/shark_inference/sharded_bloom.py +++ /dev/null @@ -1,842 +0,0 @@ -#################################################################################### -# Please make sure you have transformers 4.21.2 installed before running this demo -# -# -p --model_path: the directory in which you want to store the bloom files. -# -dl --device_list: the list of device indices you want to use. if you want to only use the first device, or you are running on cpu leave this blank. -# Otherwise, please give this argument in this format: "[0, 1, 2]" -# -de --device: the device you want to run bloom on. E.G. cpu, cuda -# -c, --recompile: set to true if you want to recompile to vmfb. -# -d, --download: set to true if you want to redownload the mlir files -# -cm, --create_mlirs: set to true if you want to create the mlir files from scratch. please make sure you have transformers 4.21.2 before using this option -# -t --token_count: the number of tokens you want to generate -# -pr --prompt: the prompt you want to feed to the model -# -m --model_name: the name of the model, e.g. bloom-560m -# -# If you don't specify a prompt when you run this example, you will be able to give prompts through the terminal. 
Run the -# example in this way if you want to run multiple examples without reinitializing the model -##################################################################################### - -import os -import io -import torch -import torch.nn as nn -from collections import OrderedDict -import torch_mlir -from torch_mlir import TensorPlaceholder -import re -from transformers.models.bloom.configuration_bloom import BloomConfig -import json -import sys -import argparse -import json -import urllib.request -import subprocess - -from torch.fx.experimental.proxy_tensor import make_fx -from torch._decomp import get_decompositions -from shark.shark_inference import SharkInference -from shark.shark_downloader import download_public_file -from transformers import ( - BloomTokenizerFast, - BloomForSequenceClassification, - BloomForCausalLM, -) -from transformers.models.bloom.modeling_bloom import ( - BloomBlock, - build_alibi_tensor, -) - -IS_CUDA = False - - -class ShardedBloom: - def __init__(self, src_folder): - f = open(f"{src_folder}/config.json") - config = json.load(f) - f.close() - - self.layers_initialized = False - - self.src_folder = src_folder - try: - self.n_embed = config["n_embed"] - except KeyError: - self.n_embed = config["hidden_size"] - self.vocab_size = config["vocab_size"] - self.n_layer = config["n_layer"] - try: - self.n_head = config["num_attention_heads"] - except KeyError: - self.n_head = config["n_head"] - - def _init_layer(self, layer_name, device, replace, device_idx): - if replace or not os.path.exists( - f"{self.src_folder}/{layer_name}.vmfb" - ): - f_ = open(f"{self.src_folder}/{layer_name}.mlir", encoding="utf-8") - module = f_.read() - f_.close() - module = bytes(module, "utf-8") - shark_module = SharkInference( - module, - device=device, - mlir_dialect="tm_tensor", - device_idx=device_idx, - ) - shark_module.save_module( - module_name=f"{self.src_folder}/{layer_name}", - extra_args=[ - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - "--iree-stream-resource-max-allocation-size=1000000000", - "--iree-codegen-check-ir-before-llvm-conversion=false", - ], - ) - else: - shark_module = SharkInference( - "", - device=device, - mlir_dialect="tm_tensor", - device_idx=device_idx, - ) - - return shark_module - - def init_layers(self, device, replace=False, device_idx=[0]): - if device_idx is not None: - n_devices = len(device_idx) - - self.word_embeddings_module = self._init_layer( - "word_embeddings", - device, - replace, - device_idx if device_idx is None else device_idx[0 % n_devices], - ) - self.word_embeddings_layernorm_module = self._init_layer( - "word_embeddings_layernorm", - device, - replace, - device_idx if device_idx is None else device_idx[1 % n_devices], - ) - self.ln_f_module = self._init_layer( - "ln_f", - device, - replace, - device_idx if device_idx is None else device_idx[2 % n_devices], - ) - self.lm_head_module = self._init_layer( - "lm_head", - device, - replace, - device_idx if device_idx is None else device_idx[3 % n_devices], - ) - self.block_modules = [ - self._init_layer( - f"bloom_block_{i}", - device, - replace, - device_idx - if device_idx is None - else device_idx[(i + 4) % n_devices], - ) - for i in range(self.n_layer) - ] - - self.layers_initialized = True - - def load_layers(self): - assert self.layers_initialized - - self.word_embeddings_module.load_module( - f"{self.src_folder}/word_embeddings.vmfb" - ) - self.word_embeddings_layernorm_module.load_module( - f"{self.src_folder}/word_embeddings_layernorm.vmfb" - ) - for 
block_module, i in zip(self.block_modules, range(self.n_layer)): - block_module.load_module(f"{self.src_folder}/bloom_block_{i}.vmfb") - self.ln_f_module.load_module(f"{self.src_folder}/ln_f.vmfb") - self.lm_head_module.load_module(f"{self.src_folder}/lm_head.vmfb") - - def forward_pass(self, input_ids, device): - if IS_CUDA: - cudaSetDevice(self.word_embeddings_module.device_idx) - - input_embeds = self.word_embeddings_module( - inputs=(input_ids,), function_name="forward" - ) - - input_embeds = torch.tensor(input_embeds).float() - if IS_CUDA: - cudaSetDevice(self.word_embeddings_layernorm_module.device_idx) - hidden_states = self.word_embeddings_layernorm_module( - inputs=(input_embeds,), function_name="forward" - ) - - hidden_states = torch.tensor(hidden_states).float() - - attention_mask = torch.ones( - [hidden_states.shape[0], len(input_ids[0])] - ) - alibi = build_alibi_tensor( - attention_mask, - self.n_head, - hidden_states.dtype, - hidden_states.device, - ) - - causal_mask = _prepare_attn_mask( - attention_mask, input_ids.size(), input_embeds, 0 - ) - causal_mask = torch.tensor(causal_mask).float() - - presents = () - all_hidden_states = tuple(hidden_states) - - for block_module, i in zip(self.block_modules, range(self.n_layer)): - if IS_CUDA: - cudaSetDevice(block_module.device_idx) - - output = block_module( - inputs=( - hidden_states.detach().numpy(), - alibi.detach().numpy(), - causal_mask.detach().numpy(), - ), - function_name="forward", - ) - hidden_states = torch.tensor(output[0]).float() - all_hidden_states = all_hidden_states + (hidden_states,) - presents = presents + ( - tuple( - ( - output[1], - output[2], - ) - ), - ) - if IS_CUDA: - cudaSetDevice(self.ln_f_module.device_idx) - - hidden_states = self.ln_f_module( - inputs=(hidden_states,), function_name="forward" - ) - if IS_CUDA: - cudaSetDevice(self.lm_head_module.device_idx) - - logits = self.lm_head_module( - inputs=(hidden_states,), function_name="forward" - ) - logits = torch.tensor(logits).float() - - return torch.argmax(logits[:, -1, :], dim=-1) - - -def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - past_key_values_length: int = 0, -): - """ - Make causal mask used for bi-directional self-attention. - """ - batch_size, target_length = input_ids_shape - mask = torch.full((target_length, target_length), torch.finfo(dtype).min) - mask_cond = torch.arange(mask.size(-1)) - intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1) - mask.masked_fill_(intermediate_mask, 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - target_length, past_key_values_length, dtype=dtype - ), - mask, - ], - dim=-1, - ) - expanded_mask = mask[None, None, :, :].expand( - batch_size, 1, target_length, target_length + past_key_values_length - ) - return expanded_mask - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
- """ - batch_size, source_length = mask.size() - tgt_len = tgt_len if tgt_len is not None else source_length - - expanded_mask = ( - mask[:, None, None, :] - .expand(batch_size, 1, tgt_len, source_length) - .to(dtype) - ) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) - - -def _prepare_attn_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length -): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - past_key_values_length=past_key_values_length, - ).to(attention_mask.device) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - -def download_model(destination_folder, model_name): - download_public_file( - f"gs://shark_tank/sharded_bloom/{model_name}/", destination_folder - ) - - -def compile_embeddings(embeddings_layer, input_ids, path): - input_ids_placeholder = torch_mlir.TensorPlaceholder.like( - input_ids, dynamic_axes=[1] - ) - module = torch_mlir.compile( - embeddings_layer, - (input_ids_placeholder), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - - bytecode_stream = io.BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - - f_ = open(path, "w+") - f_.write(str(module)) - f_.close() - return - - -def compile_word_embeddings_layernorm( - embeddings_layer_layernorm, embeds, path -): - embeds_placeholder = torch_mlir.TensorPlaceholder.like( - embeds, dynamic_axes=[1] - ) - module = torch_mlir.compile( - embeddings_layer_layernorm, - (embeds_placeholder), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - - bytecode_stream = io.BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - - f_ = open(path, "w+") - f_.write(str(module)) - f_.close() - return - - -def strip_overloads(gm): - """ - Modifies the target of graph nodes in :attr:`gm` to strip overloads. 
- Args: - gm(fx.GraphModule): The input Fx graph module to be modified - """ - for node in gm.graph.nodes: - if isinstance(node.target, torch._ops.OpOverload): - node.target = node.target.overloadpacket - gm.recompile() - - -def compile_to_mlir( - bblock, - hidden_states, - layer_past=None, - attention_mask=None, - head_mask=None, - use_cache=None, - output_attentions=False, - alibi=None, - block_index=0, - path=".", -): - fx_g = make_fx( - bblock, - decomposition_table=get_decompositions( - [ - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes, - ] - ), - tracing_mode="real", - _allow_non_fake_inputs=False, - )(hidden_states, alibi, attention_mask) - - fx_g.graph.set_codegen(torch.fx.graph.CodeGen()) - fx_g.recompile() - - strip_overloads(fx_g) - - hidden_states_placeholder = TensorPlaceholder.like( - hidden_states, dynamic_axes=[1] - ) - attention_mask_placeholder = TensorPlaceholder.like( - attention_mask, dynamic_axes=[2, 3] - ) - alibi_placeholder = TensorPlaceholder.like(alibi, dynamic_axes=[2]) - - ts_g = torch.jit.script(fx_g) - - module = torch_mlir.compile( - ts_g, - ( - hidden_states_placeholder, - alibi_placeholder, - attention_mask_placeholder, - ), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - - module_placeholder = module - module_context = module_placeholder.context - - def check_valid_line(line, line_n, mlir_file_len): - if "private" in line: - return False - if "attributes" in line: - return False - if mlir_file_len - line_n == 2: - return False - - return True - - mlir_file_len = len(str(module).split("\n")) - - def remove_constant_dim(line): - if "17x" in line: - line = re.sub("17x", "?x", line) - line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line) - if "tensor.empty" in line and "?x?" 
in line: - line = re.sub( - "tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line - ) - if "arith.cmpi eq" in line: - line = re.sub("c17", "dim", line) - if " 17," in line: - line = re.sub(" 17,", " %dim,", line) - return line - - module = "\n".join( - [ - remove_constant_dim(line) - for line, line_n in zip( - str(module).split("\n"), range(mlir_file_len) - ) - if check_valid_line(line, line_n, mlir_file_len) - ] - ) - - module = module_placeholder.parse(module, context=module_context) - bytecode_stream = io.BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - - f_ = open(path, "w+") - f_.write(str(module)) - f_.close() - return - - -def compile_ln_f(ln_f, hidden_layers, path): - hidden_layers_placeholder = torch_mlir.TensorPlaceholder.like( - hidden_layers, dynamic_axes=[1] - ) - module = torch_mlir.compile( - ln_f, - (hidden_layers_placeholder), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - - bytecode_stream = io.BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - - f_ = open(path, "w+") - f_.write(str(module)) - f_.close() - return - - -def compile_lm_head(lm_head, hidden_layers, path): - hidden_layers_placeholder = torch_mlir.TensorPlaceholder.like( - hidden_layers, dynamic_axes=[1] - ) - module = torch_mlir.compile( - lm_head, - (hidden_layers_placeholder), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - - bytecode_stream = io.BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - - f_ = open(path, "w+") - f_.write(str(module)) - f_.close() - return - - -def create_mlirs(destination_folder, model_name): - model_config = "bigscience/" + model_name - sample_input_ids = torch.ones([1, 17], dtype=torch.int64) - - urllib.request.urlretrieve( - f"https://huggingface.co/bigscience/{model_name}/resolve/main/config.json", - filename=f"{destination_folder}/config.json", - ) - urllib.request.urlretrieve( - f"https://huggingface.co/bigscience/bloom/resolve/main/tokenizer.json", - filename=f"{destination_folder}/tokenizer.json", - ) - - class HuggingFaceLanguage(torch.nn.Module): - def __init__(self): - super().__init__() - self.model = BloomForCausalLM.from_pretrained(model_config) - - def forward(self, tokens): - return self.model.forward(tokens)[0] - - class HuggingFaceBlock(torch.nn.Module): - def __init__(self, block): - super().__init__() - self.model = block - - def forward(self, tokens, alibi, attention_mask): - output = self.model( - hidden_states=tokens, - alibi=alibi, - attention_mask=attention_mask, - use_cache=True, - output_attentions=False, - ) - return (output[0], output[1][0], output[1][1]) - - model = HuggingFaceLanguage() - - compile_embeddings( - model.model.transformer.word_embeddings, - sample_input_ids, - f"{destination_folder}/word_embeddings.mlir", - ) - - inputs_embeds = model.model.transformer.word_embeddings(sample_input_ids) - - compile_word_embeddings_layernorm( - model.model.transformer.word_embeddings_layernorm, - inputs_embeds, - f"{destination_folder}/word_embeddings_layernorm.mlir", - ) - - hidden_states = model.model.transformer.word_embeddings_layernorm( - inputs_embeds - ) - - input_shape = sample_input_ids.size() - - current_sequence_length = hidden_states.shape[1] - past_key_values_length = 0 - past_key_values = tuple([None] * len(model.model.transformer.h)) - - attention_mask = torch.ones( - (hidden_states.shape[0], 
current_sequence_length), device="cpu" - ) - - alibi = build_alibi_tensor( - attention_mask, - model.model.transformer.n_head, - hidden_states.dtype, - "cpu", - ) - - causal_mask = _prepare_attn_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - head_mask = model.model.transformer.get_head_mask( - None, model.model.transformer.config.n_layer - ) - output_attentions = model.model.transformer.config.output_attentions - - all_hidden_states = () - - for i, (block, layer_past) in enumerate( - zip(model.model.transformer.h, past_key_values) - ): - all_hidden_states = all_hidden_states + (hidden_states,) - - proxy_model = HuggingFaceBlock(block) - - compile_to_mlir( - proxy_model, - hidden_states, - layer_past=layer_past, - attention_mask=causal_mask, - head_mask=head_mask[i], - use_cache=True, - output_attentions=output_attentions, - alibi=alibi, - block_index=i, - path=f"{destination_folder}/bloom_block_{i}.mlir", - ) - - compile_ln_f( - model.model.transformer.ln_f, - hidden_states, - f"{destination_folder}/ln_f.mlir", - ) - hidden_states = model.model.transformer.ln_f(hidden_states) - compile_lm_head( - model.model.lm_head, - hidden_states, - f"{destination_folder}/lm_head.mlir", - ) - - -def run_large_model( - token_count, - recompile, - model_path, - prompt, - device_list, - script_path, - device, -): - f = open(f"{model_path}/prompt.txt", "w+") - f.write(prompt) - f.close() - for i in range(token_count): - if i == 0: - will_compile = recompile - else: - will_compile = False - f = open(f"{model_path}/prompt.txt", "r") - prompt = f.read() - f.close() - - subprocess.run( - [ - "python", - script_path, - model_path, - "start", - str(will_compile), - "cpu", - "None", - prompt, - ] - ) - for i in range(config["n_layer"]): - if device_list is not None: - device_idx = str(device_list[i % len(device_list)]) - else: - device_idx = "None" - subprocess.run( - [ - "python", - script_path, - model_path, - str(i), - str(will_compile), - device, - device_idx, - prompt, - ] - ) - subprocess.run( - [ - "python", - script_path, - model_path, - "end", - str(will_compile), - "cpu", - "None", - prompt, - ] - ) - - f = open(f"{model_path}/prompt.txt", "r") - output = f.read() - f.close() - print(output) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(prog="Bloom-560m") - parser.add_argument("-p", "--model_path") - parser.add_argument("-dl", "--device_list", default=None) - parser.add_argument("-de", "--device", default="cpu") - parser.add_argument("-c", "--recompile", default=False, type=bool) - parser.add_argument("-d", "--download", default=False, type=bool) - parser.add_argument("-t", "--token_count", default=10, type=int) - parser.add_argument("-m", "--model_name", default="bloom-560m") - parser.add_argument("-cm", "--create_mlirs", default=False, type=bool) - - parser.add_argument( - "-lm", "--large_model_memory_efficient", default=False, type=bool - ) - - parser.add_argument( - "-pr", - "--prompt", - default=None, - ) - args = parser.parse_args() - - if args.create_mlirs and args.large_model_memory_efficient: - print( - "Warning: If you need to use memory efficient mode, you probably want to use 'download' instead" - ) - - if not os.path.isdir(args.model_path): - os.mkdir(args.model_path) - - if args.device_list is not None: - args.device_list = json.loads(args.device_list) - - if args.device == "cuda" and args.device_list is not None: - IS_CUDA = True - from cuda.cudart import cudaSetDevice - if args.download and args.create_mlirs: - print( - "WARNING: It 
is not advised to turn on both download and create_mlirs" - ) - if args.download: - download_model(args.model_path, args.model_name) - if args.create_mlirs: - create_mlirs(args.model_path, args.model_name) - from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig - - tokenizer = AutoTokenizer.from_pretrained(args.model_path) - if args.prompt is not None: - input_ids = tokenizer.encode(args.prompt, return_tensors="pt") - - if args.large_model_memory_efficient: - f = open(f"{args.model_path}/config.json") - config = json.load(f) - f.close() - - self_path = os.path.dirname(os.path.abspath(__file__)) - script_path = os.path.join(self_path, "sharded_bloom_large_models.py") - - if args.prompt is not None: - run_large_model( - args.token_count, - args.recompile, - args.model_path, - args.prompt, - args.device_list, - script_path, - args.device, - ) - - else: - while True: - prompt = input("Enter Prompt: ") - try: - token_count = int( - input("Enter number of tokens you want to generate: ") - ) - except: - print( - "Invalid integer entered. Using default value of 10" - ) - token_count = 10 - - run_large_model( - token_count, - args.recompile, - args.model_path, - prompt, - args.device_list, - script_path, - args.device, - ) - - else: - shardedbloom = ShardedBloom(args.model_path) - shardedbloom.init_layers( - device=args.device, - replace=args.recompile, - device_idx=args.device_list, - ) - shardedbloom.load_layers() - - if args.prompt is not None: - for _ in range(args.token_count): - next_token = shardedbloom.forward_pass( - torch.tensor(input_ids), device=args.device - ) - input_ids = torch.cat( - [input_ids, next_token.unsqueeze(-1)], dim=-1 - ) - - print(tokenizer.decode(input_ids.squeeze())) - - else: - while True: - prompt = input("Enter Prompt: ") - try: - token_count = int( - input("Enter number of tokens you want to generate: ") - ) - except: - print( - "Invalid integer entered. Using default value of 10" - ) - token_count = 10 - - input_ids = tokenizer.encode(prompt, return_tensors="pt") - - for _ in range(token_count): - next_token = shardedbloom.forward_pass( - torch.tensor(input_ids), device=args.device - ) - input_ids = torch.cat( - [input_ids, next_token.unsqueeze(-1)], dim=-1 - ) - - print(tokenizer.decode(input_ids.squeeze())) diff --git a/shark/examples/shark_inference/sharded_bloom_large_models.py b/shark/examples/shark_inference/sharded_bloom_large_models.py deleted file mode 100644 index 1635ac135c..0000000000 --- a/shark/examples/shark_inference/sharded_bloom_large_models.py +++ /dev/null @@ -1,381 +0,0 @@ -import sys -import os -from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig -import re -from shark.shark_inference import SharkInference -import torch -import torch.nn as nn -from collections import OrderedDict -from transformers.models.bloom.modeling_bloom import ( - BloomBlock, - build_alibi_tensor, -) -import time -import json - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
- """ - batch_size, source_length = mask.size() - tgt_len = tgt_len if tgt_len is not None else source_length - - expanded_mask = ( - mask[:, None, None, :] - .expand(batch_size, 1, tgt_len, source_length) - .to(dtype) - ) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) - - -def _prepare_attn_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length -): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - past_key_values_length=past_key_values_length, - ).to(attention_mask.device) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - -def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - past_key_values_length: int = 0, -): - """ - Make causal mask used for bi-directional self-attention. - """ - batch_size, target_length = input_ids_shape - mask = torch.full((target_length, target_length), torch.finfo(dtype).min) - mask_cond = torch.arange(mask.size(-1)) - intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1) - mask.masked_fill_(intermediate_mask, 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - target_length, past_key_values_length, dtype=dtype - ), - mask, - ], - dim=-1, - ) - expanded_mask = mask[None, None, :, :].expand( - batch_size, 1, target_length, target_length + past_key_values_length - ) - return expanded_mask - - -if __name__ == "__main__": - working_dir = sys.argv[1] - layer_name = sys.argv[2] - will_compile = sys.argv[3] - device = sys.argv[4] - device_idx = sys.argv[5] - prompt = sys.argv[6] - - if device_idx.lower().strip() == "none": - device_idx = None - else: - device_idx = int(device_idx) - - if will_compile.lower().strip() == "true": - will_compile = True - else: - will_compile = False - - f = open(f"{working_dir}/config.json") - config = json.load(f) - f.close() - - layers_initialized = False - try: - n_embed = config["n_embed"] - except KeyError: - n_embed = config["hidden_size"] - vocab_size = config["vocab_size"] - n_layer = config["n_layer"] - try: - n_head = config["num_attention_heads"] - except KeyError: - n_head = config["n_head"] - - if not os.path.isdir(working_dir): - os.mkdir(working_dir) - - if layer_name == "start": - tokenizer = AutoTokenizer.from_pretrained(working_dir) - input_ids = tokenizer.encode(prompt, return_tensors="pt") - - mlir_str = "" - - if will_compile: - f = open(f"{working_dir}/word_embeddings.mlir", encoding="utf-8") - mlir_str = f.read() - f.close() - - mlir_str = bytes(mlir_str, "utf-8") - - shark_module = SharkInference( - mlir_str, - device="cpu", - mlir_dialect="tm_tensor", - device_idx=None, - ) - - if will_compile: - shark_module.save_module( - module_name=f"{working_dir}/word_embeddings", - extra_args=[ - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - "--iree-stream-resource-max-allocation-size=1000000000", - "--iree-codegen-check-ir-before-llvm-conversion=false", - ], - ) - - 
shark_module.load_module(f"{working_dir}/word_embeddings.vmfb") - input_embeds = shark_module( - inputs=(input_ids,), function_name="forward" - ) - input_embeds = torch.tensor(input_embeds).float() - - mlir_str = "" - - if will_compile: - f = open( - f"{working_dir}/word_embeddings_layernorm.mlir", - encoding="utf-8", - ) - mlir_str = f.read() - f.close() - - shark_module = SharkInference( - mlir_str, - device="cpu", - mlir_dialect="tm_tensor", - device_idx=None, - ) - - if will_compile: - shark_module.save_module( - module_name=f"{working_dir}/word_embeddings_layernorm", - extra_args=[ - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - "--iree-stream-resource-max-allocation-size=1000000000", - "--iree-codegen-check-ir-before-llvm-conversion=false", - ], - ) - - shark_module.load_module( - f"{working_dir}/word_embeddings_layernorm.vmfb" - ) - hidden_states = shark_module( - inputs=(input_embeds,), function_name="forward" - ) - hidden_states = torch.tensor(hidden_states).float() - - torch.save(hidden_states, f"{working_dir}/hidden_states_0.pt") - - attention_mask = torch.ones( - [hidden_states.shape[0], len(input_ids[0])] - ) - - attention_mask = torch.tensor(attention_mask).float() - - alibi = build_alibi_tensor( - attention_mask, - n_head, - hidden_states.dtype, - device="cpu", - ) - - torch.save(alibi, f"{working_dir}/alibi.pt") - - causal_mask = _prepare_attn_mask( - attention_mask, input_ids.size(), input_embeds, 0 - ) - causal_mask = torch.tensor(causal_mask).float() - - torch.save(causal_mask, f"{working_dir}/causal_mask.pt") - - elif layer_name in [str(x) for x in range(n_layer)]: - hidden_states = torch.load( - f"{working_dir}/hidden_states_{layer_name}.pt" - ) - alibi = torch.load(f"{working_dir}/alibi.pt") - causal_mask = torch.load(f"{working_dir}/causal_mask.pt") - - mlir_str = "" - - if will_compile: - f = open( - f"{working_dir}/bloom_block_{layer_name}.mlir", - encoding="utf-8", - ) - mlir_str = f.read() - f.close() - - mlir_str = bytes(mlir_str, "utf-8") - - shark_module = SharkInference( - mlir_str, - device=device, - mlir_dialect="tm_tensor", - device_idx=device_idx, - ) - - if will_compile: - shark_module.save_module( - module_name=f"{working_dir}/bloom_block_{layer_name}", - extra_args=[ - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - "--iree-stream-resource-max-allocation-size=1000000000", - "--iree-codegen-check-ir-before-llvm-conversion=false", - ], - ) - - shark_module.load_module( - f"{working_dir}/bloom_block_{layer_name}.vmfb" - ) - - output = shark_module( - inputs=( - hidden_states.detach().numpy(), - alibi.detach().numpy(), - causal_mask.detach().numpy(), - ), - function_name="forward", - ) - - hidden_states = torch.tensor(output[0]).float() - - torch.save( - hidden_states, - f"{working_dir}/hidden_states_{int(layer_name) + 1}.pt", - ) - - elif layer_name == "end": - mlir_str = "" - - if will_compile: - f = open(f"{working_dir}/ln_f.mlir", encoding="utf-8") - mlir_str = f.read() - f.close() - - mlir_str = bytes(mlir_str, "utf-8") - - shark_module = SharkInference( - mlir_str, - device="cpu", - mlir_dialect="tm_tensor", - device_idx=None, - ) - - if will_compile: - shark_module.save_module( - module_name=f"{working_dir}/ln_f", - extra_args=[ - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - "--iree-stream-resource-max-allocation-size=1000000000", - "--iree-codegen-check-ir-before-llvm-conversion=false", - ], - ) - - shark_module.load_module(f"{working_dir}/ln_f.vmfb") - - hidden_states = 
torch.load(f"{working_dir}/hidden_states_{n_layer}.pt") - - hidden_states = shark_module( - inputs=(hidden_states,), function_name="forward" - ) - - mlir_str = "" - - if will_compile: - f = open(f"{working_dir}/lm_head.mlir", encoding="utf-8") - mlir_str = f.read() - f.close() - - mlir_str = bytes(mlir_str, "utf-8") - - if config["n_embed"] == 14336: - - def get_state_dict(): - d = torch.load( - f"{working_dir}/pytorch_model_00001-of-00072.bin" - ) - return OrderedDict( - (k.replace("word_embeddings.", ""), v) - for k, v in d.items() - ) - - def load_causal_lm_head(): - linear = nn.utils.skip_init( - nn.Linear, 14336, 250880, bias=False, dtype=torch.float - ) - linear.load_state_dict(get_state_dict(), strict=False) - return linear.float() - - lm_head = load_causal_lm_head() - - logits = lm_head(torch.tensor(hidden_states).float()) - - else: - shark_module = SharkInference( - mlir_str, - device="cpu", - mlir_dialect="tm_tensor", - device_idx=None, - ) - - if will_compile: - shark_module.save_module( - module_name=f"{working_dir}/lm_head", - extra_args=[ - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - "--iree-stream-resource-max-allocation-size=1000000000", - "--iree-codegen-check-ir-before-llvm-conversion=false", - ], - ) - - shark_module.load_module(f"{working_dir}/lm_head.vmfb") - - logits = shark_module( - inputs=(hidden_states,), function_name="forward" - ) - - logits = torch.tensor(logits).float() - - tokenizer = AutoTokenizer.from_pretrained(working_dir) - - next_token = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=-1)) - - f = open(f"{working_dir}/prompt.txt", "w+") - f.write(prompt + next_token) - f.close() diff --git a/shark/examples/shark_inference/simple_dlrm.py b/shark/examples/shark_inference/simple_dlrm.py deleted file mode 100644 index fd1056a48f..0000000000 --- a/shark/examples/shark_inference/simple_dlrm.py +++ /dev/null @@ -1,390 +0,0 @@ -# Description: an implementation of a deep learning recommendation model (DLRM) -# The model input consists of dense and sparse features. The former is a vector -# of floating point values. The latter is a list of sparse indices into -# embedding tables, which consist of vectors of floating point values. -# The selected vectors are passed to mlp networks denoted by triangles, -# in some cases the vectors are interacted through operators (Ops). -# -# output: -# vector of values -# model: | -# /\ -# /__\ -# | -# _____________________> Op <___________________ -# / | \ -# /\ /\ /\ -# /__\ /__\ ... /__\ -# | | | -# | Op Op -# | ____/__\_____ ____/__\____ -# | |_Emb_|____|__| ... |_Emb_|__|___| -# input: -# [ dense features ] [sparse indices] , ..., [sparse indices] -# -# More precise definition of model layers: -# 1) fully connected layers of an mlp -# z = f(y) -# y = Wx + b -# -# 2) embedding lookup (for a list of sparse indices p=[p1,...,pk]) -# z = Op(e1,...,ek) -# obtain vectors e1=E[:,p1], ..., ek=E[:,pk] -# -# 3) Operator Op can be one of the following -# Sum(e1,...,ek) = e1 + ... + ek -# Dot(e1,...,ek) = [e1'e1, ..., e1'ek, ..., ek'e1, ..., ek'ek] -# Cat(e1,...,ek) = [e1', ..., ek']' -# where ' denotes transpose operation -# -# References: -# [1] Maxim Naumov, Dheevatsa Mudigere, Hao-Jun Michael Shi, Jianyu Huang, -# Narayanan Sundaram, Jongsoo Park, Xiaodong Wang, Udit Gupta, Carole-Jean Wu, -# Alisson G. 
Azzolini, Dmytro Dzhulgakov, Andrey Mallevich, Ilia Cherniavskii, -# Yinghai Lu, Raghuraman Krishnamoorthi, Ansha Yu, Volodymyr Kondratenko, -# Stephanie Pereira, Xianjie Chen, Wenlin Chen, Vijay Rao, Bill Jia, Liang Xiong, -# Misha Smelyanskiy, "Deep Learning Recommendation Model for Personalization and -# Recommendation Systems", CoRR, arXiv:1906.00091, 2019 - - -import argparse -import sys -import numpy as np -import torch -import torch.nn as nn -from shark.shark_inference import SharkInference -from shark.shark_importer import SharkImporter - - -torch.manual_seed(0) -np.random.seed(0) - - -### define dlrm in PyTorch ### -class DLRM_Net(nn.Module): - def create_mlp(self, ln, sigmoid_layer): - # build MLP layer by layer - layers = nn.ModuleList() - for i in range(0, ln.size - 1): - n = ln[i] - m = ln[i + 1] - - # construct fully connected operator - LL = nn.Linear(int(n), int(m), bias=True) - - # initialize the weights - # with torch.no_grad(): - # custom Xavier input, output or two-sided fill - - mean = 0.0 # std_dev = np.sqrt(variance) - std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n) - W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32) - std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1)) - bt = np.random.normal(mean, std_dev, size=m).astype(np.float32) - LL.weight.data = torch.tensor(W, requires_grad=True) - LL.bias.data = torch.tensor(bt, requires_grad=True) - - # approach 2 - # LL.weight.data.copy_(torch.tensor(W)) - # LL.bias.data.copy_(torch.tensor(bt)) - # approach 3 - # LL.weight = Parameter(torch.tensor(W),requires_grad=True) - # LL.bias = Parameter(torch.tensor(bt),requires_grad=True) - layers.append(LL) - - # construct sigmoid or relu operator - if i == sigmoid_layer: - layers.append(nn.Sigmoid()) - else: - layers.append(nn.ReLU()) - - # approach 1: use ModuleList - # return layers - # approach 2: use Sequential container to wrap all layers - return torch.nn.Sequential(*layers) - - def create_emb(self, m, ln, weighted_pooling=None): - emb_l = nn.ModuleList() - v_W_l = [] - for i in range(0, ln.size): - n = ln[i] - - # construct embedding operator - EE = nn.EmbeddingBag(n, m, mode="sum") - # initialize embeddings - # nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n)) - W = np.random.uniform( - low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m) - ).astype(np.float32) - # approach 1 - print(W) - EE.weight.data = torch.tensor(W, requires_grad=True) - # approach 2 - # EE.weight.data.copy_(torch.tensor(W)) - # approach 3 - # EE.weight = Parameter(torch.tensor(W),requires_grad=True) - if weighted_pooling is None: - v_W_l.append(None) - else: - v_W_l.append(torch.ones(n, dtype=torch.float32)) - emb_l.append(EE) - return emb_l, v_W_l - - def __init__( - self, - m_spa=None, - ln_emb=None, - ln_bot=None, - ln_top=None, - arch_interaction_op=None, - arch_interaction_itself=False, - sigmoid_bot=-1, - sigmoid_top=-1, - weighted_pooling=None, - ): - super(DLRM_Net, self).__init__() - - if ( - (m_spa is not None) - and (ln_emb is not None) - and (ln_bot is not None) - and (ln_top is not None) - and (arch_interaction_op is not None) - ): - # save arguments - self.output_d = 0 - self.arch_interaction_op = arch_interaction_op - self.arch_interaction_itself = arch_interaction_itself - if weighted_pooling is not None and weighted_pooling != "fixed": - self.weighted_pooling = "learned" - else: - self.weighted_pooling = weighted_pooling - - # create operators - self.emb_l, w_list = self.create_emb( - m_spa, ln_emb, weighted_pooling - ) - if 
self.weighted_pooling == "learned": - self.v_W_l = nn.ParameterList() - for w in w_list: - self.v_W_l.append(nn.Parameter(w)) - else: - self.v_W_l = w_list - self.bot_l = self.create_mlp(ln_bot, sigmoid_bot) - self.top_l = self.create_mlp(ln_top, sigmoid_top) - - def apply_mlp(self, x, layers): - return layers(x) - - def apply_emb(self, lS_o, lS_i, emb_l, v_W_l): - # WARNING: notice that we are processing the batch at once. We implicitly - # assume that the data is laid out such that: - # 1. each embedding is indexed with a group of sparse indices, - # corresponding to a single lookup - # 2. for each embedding the lookups are further organized into a batch - # 3. for a list of embedding tables there is a list of batched lookups - # TORCH-MLIR - # We are passing all the embeddings as arguments for easy parsing. - - ly = [] - for k, sparse_index_group_batch in enumerate(lS_i): - sparse_offset_group_batch = lS_o[k] - - # embedding lookup - # We are using EmbeddingBag, which implicitly uses sum operator. - # The embeddings are represented as tall matrices, with sum - # happening vertically across 0 axis, resulting in a row vector - # E = emb_l[k] - - if v_W_l[k] is not None: - per_sample_weights = v_W_l[k].gather( - 0, sparse_index_group_batch - ) - else: - per_sample_weights = None - - E = emb_l[k] - V = E( - sparse_index_group_batch, - sparse_offset_group_batch, - per_sample_weights=per_sample_weights, - ) - - ly.append(V) - - return ly - - def interact_features(self, x, ly): - if self.arch_interaction_op == "dot": - # concatenate dense and sparse features - (batch_size, d) = x.shape - T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d)) - # perform a dot product - Z = torch.bmm(T, torch.transpose(T, 1, 2)) - # append dense feature with the interactions (into a row vector) - # approach 1: all - # Zflat = Z.view((batch_size, -1)) - # approach 2: unique - _, ni, nj = Z.shape - # approach 1: tril_indices - # offset = 0 if self.arch_interaction_itself else -1 - # li, lj = torch.tril_indices(ni, nj, offset=offset) - # approach 2: custom - offset = 1 if self.arch_interaction_itself else 0 - li = torch.tensor( - [i for i in range(ni) for j in range(i + offset)] - ) - lj = torch.tensor( - [j for i in range(nj) for j in range(i + offset)] - ) - Zflat = Z[:, li, lj] - # concatenate dense features and interactions - R = torch.cat([x] + [Zflat], dim=1) - elif self.arch_interaction_op == "cat": - # concatenation features (into a row vector) - R = torch.cat([x] + ly, dim=1) - else: - sys.exit( - "ERROR: --arch-interaction-op=" - + self.arch_interaction_op - + " is not supported" - ) - - return R - - def forward(self, dense_x, lS_o, *lS_i): - return self.sequential_forward(dense_x, lS_o, lS_i) - - def sequential_forward(self, dense_x, lS_o, lS_i): - # process dense features (using bottom mlp), resulting in a row vector - x = self.apply_mlp(dense_x, self.bot_l) - # debug prints - # print("intermediate") - # print(x.detach().cpu().numpy()) - - # process sparse features(using embeddings), resulting in a list of row vectors - ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) - # for y in ly: - # print(y.detach().cpu().numpy()) - - # interact features (dense and sparse) - z = self.interact_features(x, ly) - # print(z.detach().cpu().numpy()) - - # obtain probability of a click (using top mlp) - p = self.apply_mlp(z, self.top_l) - - # # clamp output if needed - # if 0.0 < self.loss_threshold and self.loss_threshold < 1.0: - # z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold)) - 
# else: - # z = p - - return p - - -def dash_separated_ints(value): - vals = value.split("-") - for val in vals: - try: - int(val) - except ValueError: - raise argparse.ArgumentTypeError( - "%s is not a valid dash separated list of ints" % value - ) - - return value - - -# model related parameters -parser = argparse.ArgumentParser( - description="Train Deep Learning Recommendation Model (DLRM)" -) -parser.add_argument("--arch-sparse-feature-size", type=int, default=2) -parser.add_argument( - "--arch-embedding-size", type=dash_separated_ints, default="4-3-2" -) -# j will be replaced with the table number -parser.add_argument( - "--arch-mlp-bot", type=dash_separated_ints, default="4-3-2" -) -parser.add_argument( - "--arch-mlp-top", type=dash_separated_ints, default="8-2-1" -) -parser.add_argument( - "--arch-interaction-op", type=str, choices=["dot", "cat"], default="dot" -) -parser.add_argument( - "--arch-interaction-itself", action="store_true", default=False -) -parser.add_argument("--weighted-pooling", type=str, default=None) - -args = parser.parse_args() - -ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-") -ln_top = np.fromstring(args.arch_mlp_top, dtype=int, sep="-") -m_den = ln_bot[0] -ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-") -m_spa = args.arch_sparse_feature_size -ln_emb = np.asarray(ln_emb) -num_fea = ln_emb.size + 1 # num sparse + num dense features - - -# Initialize the model. -dlrm_model = DLRM_Net( - m_spa=m_spa, - ln_emb=ln_emb, - ln_bot=ln_bot, - ln_top=ln_top, - arch_interaction_op=args.arch_interaction_op, -) - - -# Inputs to the model. -dense_inp = torch.tensor([[0.6965, 0.2861, 0.2269, 0.5513]]) -vs0 = torch.tensor([[0], [0], [0]], dtype=torch.int64) -vsi = torch.tensor([1, 2, 3]), torch.tensor([1]), torch.tensor([1]) - -input_dlrm = (dense_inp, vs0, *vsi) - -golden_output = dlrm_model(dense_inp, vs0, *vsi) - -mlir_importer = SharkImporter( - dlrm_model, - input_dlrm, - frontend="torch", -) - -(dlrm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug( - tracing_required=True -) - -shark_module = SharkInference( - dlrm_mlir, device="vulkan", mlir_dialect="linalg" -) -shark_module.compile() -result = shark_module.forward(input_dlrm) -np.testing.assert_allclose( - golden_output.detach().numpy(), result, rtol=1e-02, atol=1e-03 -) - - -# Verified via torch-mlir. 
-# import torch_mlir -# from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend - - -# module = torch_mlir.compile( -# dlrm_model, inputs, use_tracing=True, output_type="linalg-on-tensors" -# ) -# backend = refbackend.RefBackendLinalgOnTensorsBackend() -# compiled = backend.compile(module) -# jit_module = backend.load(compiled) - -# dense_numpy = dense_inp.numpy() -# vs0_numpy = vs0.numpy() -# vsi_numpy = [inp.numpy() for inp in vsi] - -# numpy_inp = (dense_numpy, vs0_numpy, *vsi_numpy) - -# print(jit_module.forward(*numpy_inp)) diff --git a/shark/examples/shark_inference/sparse_arch.py b/shark/examples/shark_inference/sparse_arch.py deleted file mode 100644 index 6cf6cddf82..0000000000 --- a/shark/examples/shark_inference/sparse_arch.py +++ /dev/null @@ -1,311 +0,0 @@ -import torch -from torch import nn -from torchrec.datasets.utils import Batch -from torchrec.modules.crossnet import LowRankCrossNet -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor -from typing import Dict, List, Optional, Tuple -from torchrec.models.dlrm import ( - choose, - DenseArch, - DLRM, - InteractionArch, - SparseArch, - OverArch, -) -from shark.shark_inference import SharkInference -from shark.shark_importer import SharkImporter -import numpy as np - -torch.manual_seed(0) - -np.random.seed(0) - - -def calculate_offsets(tensor_list, prev_values, prev_offsets): - offset_init = 0 - offset_list = [] - values_list = [] - - if prev_offsets != None: - offset_init = prev_values.shape[-1] - for tensor in tensor_list: - offset_list.append(offset_init) - offset_init += tensor.shape[0] - - concatendated_tensor_list = torch.cat(tensor_list) - - if prev_values != None: - concatendated_tensor_list = torch.cat( - [prev_values, concatendated_tensor_list] - ) - - concatenated_offsets = torch.tensor(offset_list) - - if prev_offsets != None: - concatenated_offsets = torch.cat([prev_offsets, concatenated_offsets]) - - return concatendated_tensor_list, concatenated_offsets - - -# Have to make combined_keys as dict as to which embedding bags they -# point to. {f1: 0, f3: 0, f2: 1} -# The result will be a triple containing values, indices and pointer tensor. 
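For reference, a minimal standalone sketch (toy sizes assumed) of the flattened values/offsets layout that torch.nn.EmbeddingBag consumes; the helper below builds this values/offsets pair, plus an integer pointer selecting which embedding bag to use, for each feature key:

import torch
import torch.nn as nn

# Three bags of indices, [1, 2], [4] and [3, 0], are concatenated into a single
# values tensor; offsets records where each bag starts.
values = torch.tensor([1, 2, 4, 3, 0])
offsets = torch.tensor([0, 2, 3])

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode="sum")
pooled = bag(values, offsets)  # one pooled 3-dim embedding per bag -> shape (3, 3)
print(pooled.shape)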
-def to_list(key_jagged, combined_keys): - key_jagged_dict = key_jagged.to_dict() - combined_list = [] - - for key in combined_keys: - prev_values, prev_offsets = calculate_offsets( - key_jagged_dict[key].to_dense(), None, None - ) - print(prev_values) - print(prev_offsets) - combined_list.append(prev_values) - combined_list.append(prev_offsets) - combined_list.append(torch.tensor(combined_keys[key])) - - return combined_list - - -class SparseArchShark(nn.Module): - def create_emb(self, embedding_dim, num_embeddings_list): - embedding_list = nn.ModuleList() - for i in range(0, num_embeddings_list.size): - num_embeddings = num_embeddings_list[i] - EE = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum") - W = np.random.uniform( - low=-np.sqrt(1 / num_embeddings), - high=np.sqrt(1 / num_embeddings), - size=(num_embeddings, embedding_dim), - ).astype(np.float32) - EE.weight.data = torch.tensor(W, requires_grad=True) - embedding_list.append(EE) - return embedding_list - - def __init__( - self, - embedding_dim, - total_features, - num_embeddings_list, - ): - super(SparseArchShark, self).__init__() - self.embedding_dim = embedding_dim - self.num_features = total_features - self.embedding_list = self.create_emb( - embedding_dim, num_embeddings_list - ) - - def forward(self, *batched_inputs): - concatenated_list = [] - input_enum, embedding_enum = 0, 0 - - for k in range(len(batched_inputs) // 3): - values = batched_inputs[input_enum] - input_enum += 1 - offsets = batched_inputs[input_enum] - input_enum += 1 - embedding_pointer = int(batched_inputs[input_enum]) - input_enum += 1 - - E = self.embedding_list[embedding_pointer] - V = E(values, offsets) - concatenated_list.append(V) - - return torch.cat(concatenated_list, dim=1).reshape( - -1, self.num_features, self.embedding_dim - ) - - -def test_sparse_arch() -> None: - D = 3 - eb1_config = EmbeddingBagConfig( - name="t1", - embedding_dim=D, - num_embeddings=10, - feature_names=["f1", "f3"], - ) - eb2_config = EmbeddingBagConfig( - name="t2", - embedding_dim=D, - num_embeddings=10, - feature_names=["f2"], - ) - - ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) - - w1 = ebc.embedding_bags["t1"].weight - w2 = ebc.embedding_bags["t2"].weight - - sparse_arch = SparseArch(ebc) - - keys = ["f1", "f2", "f3", "f4", "f5"] - offsets = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 19]) - features = KeyedJaggedTensor.from_offsets_sync( - keys=keys, - values=torch.tensor( - [1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3] - ), - offsets=offsets, - ) - sparse_archi = SparseArchShark(D, 3, np.array([10, 10])) - sparse_archi.embedding_list[0].weight = w1 - sparse_archi.embedding_list[1].weight = w2 - inputs = to_list(features, {"f1": 0, "f3": 0, "f2": 1}) - - test_results = sparse_archi(*inputs) - sparse_features = sparse_arch(features) - - torch.allclose( - sparse_features, - test_results, - rtol=1e-4, - atol=1e-4, - ) - - -test_sparse_arch() - - -class DLRMShark(nn.Module): - def __init__( - self, - embedding_dim, - total_features, - num_embeddings_list, - dense_in_features: int, - dense_arch_layer_sizes: List[int], - over_arch_layer_sizes: List[int], - ) -> None: - super().__init__() - - self.sparse_arch: SparseArchShark = SparseArchShark( - embedding_dim, total_features, num_embeddings_list - ) - num_sparse_features: int = total_features - - self.dense_arch = DenseArch( - in_features=dense_in_features, - layer_sizes=dense_arch_layer_sizes, - ) - - self.inter_arch = InteractionArch( - num_sparse_features=num_sparse_features, - ) 
- - over_in_features: int = ( - embedding_dim - + choose(num_sparse_features, 2) - + num_sparse_features - ) - - self.over_arch = OverArch( - in_features=over_in_features, - layer_sizes=over_arch_layer_sizes, - ) - - def forward( - self, dense_features: torch.Tensor, *sparse_features - ) -> torch.Tensor: - embedded_dense = self.dense_arch(dense_features) - embedded_sparse = self.sparse_arch(*sparse_features) - concatenated_dense = self.inter_arch( - dense_features=embedded_dense, sparse_features=embedded_sparse - ) - logits = self.over_arch(concatenated_dense) - return logits - - -def test_dlrm() -> None: - B = 2 - D = 8 - dense_in_features = 100 - - eb1_config = EmbeddingBagConfig( - name="t1", - embedding_dim=D, - num_embeddings=100, - feature_names=["f1", "f3"], - ) - eb2_config = EmbeddingBagConfig( - name="t2", - embedding_dim=D, - num_embeddings=100, - feature_names=["f2"], - ) - - ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) - - sparse_features = KeyedJaggedTensor.from_offsets_sync( - keys=["f1", "f3", "f2"], - values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]), - offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]), - ) - ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) - sparse_nn = DLRM( - embedding_bag_collection=ebc, - dense_in_features=dense_in_features, - dense_arch_layer_sizes=[20, D], - over_arch_layer_sizes=[5, 1], - ) - sparse_nn_nod = DLRMShark( - embedding_dim=8, - total_features=3, - num_embeddings_list=np.array([100, 100]), - dense_in_features=dense_in_features, - dense_arch_layer_sizes=[20, D], - over_arch_layer_sizes=[5, 1], - ) - - dense_features = torch.rand((B, dense_in_features)) - - x = to_list(sparse_features, {"f1": 0, "f3": 0, "f2": 1}) - - w1 = ebc.embedding_bags["t1"].weight - w2 = ebc.embedding_bags["t2"].weight - - sparse_nn_nod.sparse_arch.embedding_list[0].weight = w1 - sparse_nn_nod.sparse_arch.embedding_list[1].weight = w2 - - sparse_nn_nod.dense_arch.load_state_dict(sparse_nn.dense_arch.state_dict()) - sparse_nn_nod.inter_arch.load_state_dict(sparse_nn.inter_arch.state_dict()) - sparse_nn_nod.over_arch.load_state_dict(sparse_nn.over_arch.state_dict()) - - logits = sparse_nn( - dense_features=dense_features, - sparse_features=sparse_features, - ) - logits_nod = sparse_nn_nod(dense_features, *x) - - # print(logits) - # print(logits_nod) - - # Import the module and print. 
- mlir_importer = SharkImporter( - sparse_nn_nod, - (dense_features, *x), - frontend="torch", - ) - - (dlrm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug( - tracing_required=True - ) - - shark_module = SharkInference( - dlrm_mlir, device="cpu", mlir_dialect="linalg" - ) - shark_module.compile() - result = shark_module.forward(inputs) - np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03) - - torch.allclose( - logits, - logits_nod, - rtol=1e-4, - atol=1e-4, - ) - - -test_dlrm() diff --git a/shark/examples/shark_inference/t5_tf.py b/shark/examples/shark_inference/t5_tf.py deleted file mode 100644 index e72004a822..0000000000 --- a/shark/examples/shark_inference/t5_tf.py +++ /dev/null @@ -1,35 +0,0 @@ -from PIL import Image -import requests - -from transformers import T5Tokenizer, TFT5Model -import tensorflow as tf -from shark.shark_inference import SharkInference - -# Create a set of inputs -t5_inputs = [ - tf.TensorSpec(shape=[1, 10], dtype=tf.int32), - tf.TensorSpec(shape=[1, 10], dtype=tf.int32), -] - - -class T5Module(tf.Module): - def __init__(self): - super(T5Module, self).__init__() - self.m = TFT5Model.from_pretrained("t5-small") - self.m.predict = lambda x, y: self.m(input_ids=x, decoder_input_ids=y) - - @tf.function(input_signature=t5_inputs, jit_compile=True) - def forward(self, input_ids, decoder_input_ids): - return self.m.predict(input_ids, decoder_input_ids) - - -if __name__ == "__main__": - # Prepping Data - tokenizer = T5Tokenizer.from_pretrained("t5-small") - text = "I love the distilled version of models." - inputs = tokenizer(text, return_tensors="tf").input_ids - - shark_module = SharkInference(T5Module(), (inputs, inputs)) - shark_module.set_frontend("tensorflow") - shark_module.compile() - print(shark_module.forward((inputs, inputs))) diff --git a/shark/examples/shark_inference/torch_vision_models_script.py b/shark/examples/shark_inference/torch_vision_models_script.py deleted file mode 100644 index ae66883b25..0000000000 --- a/shark/examples/shark_inference/torch_vision_models_script.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch -import torchvision.models as models -from shark.shark_inference import SharkInference - - -class VisionModule(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - self.train(False) - - def forward(self, input): - return self.model.forward(input) - - -input = torch.randn(1, 3, 224, 224) - -## The vision models present here: https://pytorch.org/vision/stable/models.html -vision_models_list = [ - models.resnet18(pretrained=True), - models.alexnet(pretrained=True), - models.vgg16(pretrained=True), - models.squeezenet1_0(pretrained=True), - models.densenet161(pretrained=True), - models.inception_v3(pretrained=True), - models.shufflenet_v2_x1_0(pretrained=True), - models.mobilenet_v2(pretrained=True), - models.mobilenet_v3_small(pretrained=True), - models.resnext50_32x4d(pretrained=True), - models.wide_resnet50_2(pretrained=True), - models.mnasnet1_0(pretrained=True), - models.efficientnet_b0(pretrained=True), - models.regnet_y_400mf(pretrained=True), - models.regnet_x_400mf(pretrained=True), -] - -for i, vision_model in enumerate(vision_models_list): - shark_module = SharkInference( - VisionModule(vision_model), - (input,), - ) - shark_module.compile() - shark_module.forward((input,)) diff --git a/shark/examples/shark_inference/unet_script.py b/shark/examples/shark_inference/unet_script.py deleted file mode 100644 index 01b5eebe02..0000000000 --- 
a/shark/examples/shark_inference/unet_script.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch -import numpy as np -from shark.shark_inference import SharkInference -from shark.shark_importer import SharkImporter - - -class UnetModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.model = torch.hub.load( - "mateuszbuda/brain-segmentation-pytorch", - "unet", - in_channels=3, - out_channels=1, - init_features=32, - pretrained=True, - ) - self.model.eval() - - def forward(self, input): - return self.model(input) - - -input = torch.randn(1, 3, 224, 224) - -mlir_importer = SharkImporter( - UnetModule(), - (input,), - frontend="torch", -) - -(vision_mlir, func_name), inputs, golden_out = mlir_importer.import_debug( - tracing_required=False -) - -shark_module = SharkInference(vision_mlir, mlir_dialect="linalg") -shark_module.compile() -result = shark_module.forward((input,)) -np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03) diff --git a/shark/examples/shark_inference/upscaler/main.py b/shark/examples/shark_inference/upscaler/main.py deleted file mode 100644 index e5ad11ac28..0000000000 --- a/shark/examples/shark_inference/upscaler/main.py +++ /dev/null @@ -1,21 +0,0 @@ -import requests -from PIL import Image -from io import BytesIO -from pipeline_shark_stable_diffusion_upscale import ( - SharkStableDiffusionUpscalePipeline, -) -import torch - -model_id = "stabilityai/stable-diffusion-x4-upscaler" -pipeline = SharkStableDiffusionUpscalePipeline(model_id) - -# let's download an image -url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png" -response = requests.get(url) -low_res_img = Image.open(BytesIO(response.content)).convert("RGB") -low_res_img = low_res_img.resize((128, 128)) - -prompt = "a white cat" - -upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] -upscaled_image.save("upsampled_cat.png") diff --git a/shark/examples/shark_inference/upscaler/model_wrappers.py b/shark/examples/shark_inference/upscaler/model_wrappers.py deleted file mode 100644 index 1de747c4a2..0000000000 --- a/shark/examples/shark_inference/upscaler/model_wrappers.py +++ /dev/null @@ -1,98 +0,0 @@ -from diffusers import AutoencoderKL, UNet2DConditionModel -from transformers import CLIPTextModel -from utils import compile_through_fx -import torch - -model_id = "stabilityai/stable-diffusion-x4-upscaler" - -model_input = { - "clip": (torch.randint(1, 2, (1, 77)),), - "vae": (torch.randn(1, 4, 128, 128),), - "unet": ( - torch.randn(2, 7, 128, 128), # latents - torch.tensor([1]).to(torch.float32), # timestep - torch.randn(2, 77, 1024), # embedding - torch.randn(2).to(torch.int64), # noise_level - ), -} - - -def get_clip_mlir(model_name="clip_text", extra_args=[]): - text_encoder = CLIPTextModel.from_pretrained( - model_id, - subfolder="text_encoder", - ) - - class CLIPText(torch.nn.Module): - def __init__(self): - super().__init__() - self.text_encoder = text_encoder - - def forward(self, input): - return self.text_encoder(input)[0] - - clip_model = CLIPText() - shark_clip = compile_through_fx( - clip_model, - model_input["clip"], - model_name=model_name, - extra_args=extra_args, - ) - return shark_clip - - -def get_vae_mlir(model_name="vae", extra_args=[]): - class VaeModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.vae = AutoencoderKL.from_pretrained( - model_id, - subfolder="vae", - ) - - def forward(self, input): - x = self.vae.decode(input, return_dict=False)[0] - return x - - 
vae = VaeModel() - shark_vae = compile_through_fx( - vae, - model_input["vae"], - model_name=model_name, - extra_args=extra_args, - ) - return shark_vae - - -def get_unet_mlir(model_name="unet", extra_args=[]): - class UnetModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.unet = UNet2DConditionModel.from_pretrained( - model_id, - subfolder="unet", - ) - self.in_channels = self.unet.in_channels - self.train(False) - - def forward(self, latent, timestep, text_embedding, noise_level): - unet_out = self.unet.forward( - latent, - timestep, - text_embedding, - noise_level, - return_dict=False, - )[0] - return unet_out - - unet = UnetModel() - f16_input_mask = (True, True, True, False) - shark_unet = compile_through_fx( - unet, - model_input["unet"], - model_name=model_name, - is_f16=True, - f16_input_mask=f16_input_mask, - extra_args=extra_args, - ) - return shark_unet diff --git a/shark/examples/shark_inference/upscaler/opt_params.py b/shark/examples/shark_inference/upscaler/opt_params.py deleted file mode 100644 index d293e89e65..0000000000 --- a/shark/examples/shark_inference/upscaler/opt_params.py +++ /dev/null @@ -1,48 +0,0 @@ -import sys -from model_wrappers import ( - get_vae_mlir, - get_unet_mlir, - get_clip_mlir, -) -from upscaler_args import args -from utils import get_shark_model - -BATCH_SIZE = len(args.prompts) -if BATCH_SIZE != 1: - sys.exit("Only batch size 1 is supported.") - - -unet_flag = [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))" -] - -vae_flag = [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc,iree-preprocessing-pad-linalg-ops{pad-size=16}))" -] - -clip_flag = [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))" -] - -bucket = "gs://shark_tank/stable_diffusion/" - - -def get_unet(): - model_name = "upscaler_unet" - if args.import_mlir: - return get_unet_mlir(model_name, unet_flag) - return get_shark_model(bucket, model_name, unet_flag) - - -def get_vae(): - model_name = "upscaler_vae" - if args.import_mlir: - return get_vae_mlir(model_name, vae_flag) - return get_shark_model(bucket, model_name, vae_flag) - - -def get_clip(): - model_name = "upscaler_clip" - if args.import_mlir: - return get_clip_mlir(model_name, clip_flag) - return get_shark_model(bucket, model_name, clip_flag) diff --git a/shark/examples/shark_inference/upscaler/pipeline_shark_stable_diffusion_upscale.py b/shark/examples/shark_inference/upscaler/pipeline_shark_stable_diffusion_upscale.py deleted file mode 100644 index d1aeb53897..0000000000 --- a/shark/examples/shark_inference/upscaler/pipeline_shark_stable_diffusion_upscale.py +++ /dev/null @@ -1,489 +0,0 @@ -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import torch - -import PIL -from PIL import Image -from diffusers.utils import is_accelerate_available -from transformers import CLIPTextModel, CLIPTokenizer -from diffusers import AutoencoderKL, UNet2DConditionModel -from diffusers import ( - DDIMScheduler, - DDPMScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) -from diffusers import logging -from diffusers.pipeline_utils import ImagePipelineOutput -from opt_params import get_unet, get_vae, get_clip -from tqdm.auto import tqdm - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def preprocess(image): - if isinstance(image, 
torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = map( - lambda x: x - x % 64, (w, h) - ) # resize to integer multiple of 64 - - image = [np.array(i.resize((w, h)))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -def shark_run_wrapper(model, *args): - np_inputs = tuple([x.detach().numpy() for x in args]) - outputs = model("forward", np_inputs) - return torch.from_numpy(outputs) - - -class SharkStableDiffusionUpscalePipeline: - def __init__( - self, - model_id, - ): - self.tokenizer = CLIPTokenizer.from_pretrained( - model_id, subfolder="tokenizer" - ) - self.low_res_scheduler = DDPMScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - self.scheduler = DDIMScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - self.vae = get_vae() - self.unet = get_unet() - self.text_encoder = get_clip() - self.max_noise_level = (350,) - self._execution_device = "cpu" - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - ): - r""" - Encodes the prompt into text encoder hidden states. - Args: - prompt (`str` or `list(int)`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). 
- """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - # if ( - # hasattr(self.text_encoder.config, "use_attention_mask") - # and self.text_encoder.config.use_attention_mask - # ): - # attention_mask = text_inputs.attention_mask.to(device) - # else: - # attention_mask = None - - text_embeddings = shark_run_wrapper( - self.text_encoder, text_input_ids.to(device) - ) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - # if ( - # hasattr(self.text_encoder.config, "use_attention_mask") - # and self.text_encoder.config.use_attention_mask - # ): - # attention_mask = uncond_input.attention_mask.to(device) - # else: - # attention_mask = None - - uncond_embeddings = shark_run_wrapper( - self.text_encoder, - uncond_input.input_ids.to(device), - ) - uncond_embeddings = uncond_embeddings - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat( - 1, num_images_per_prompt, 1 - ) - uncond_embeddings = uncond_embeddings.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - return text_embeddings - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333 - def decode_latents(self, latents): - latents = 1 / 0.08333 * latents - image = shark_run_wrapper(self.vae, latents) - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def check_inputs(self, prompt, image, noise_level, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) - - if ( - not isinstance(image, torch.Tensor) - and not isinstance(image, PIL.Image.Image) - and not isinstance(image, list) - ): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" - ) - - # verify batch size of prompt and image are same if image is a list or tensor - if isinstance(image, list) or isinstance(image, torch.Tensor): - if isinstance(prompt, str): - batch_size = 1 - else: - batch_size = len(prompt) - if isinstance(image, list): - image_batch_size = len(image) - else: - image_batch_size = image.shape[0] - if batch_size != image_batch_size: - raise ValueError( - f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." - " Please make sure that passed `prompt` matches the batch size of `image`." - ) - - @staticmethod - def numpy_to_pil(images): - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] 
- images = (images * 255).round().astype("uint8") - if images.shape[-1] == 1: - # special case for grayscale (single channel) images - pil_images = [ - Image.fromarray(image.squeeze(), mode="L") for image in images - ] - else: - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - shape = (batch_size, num_channels_latents, height, width) - if latents is None: - if device == "mps": - # randn does not work reproducibly on mps - latents = torch.randn( - shape, generator=generator, device="cpu", dtype=dtype - ).to(device) - else: - latents = torch.randn( - shape, generator=generator, device=device, dtype=dtype - ) - else: - if latents.shape != shape: - raise ValueError( - f"Unexpected latents shape, got {latents.shape}, expected {shape}" - ) - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - image: Union[ - torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image] - ], - num_inference_steps: int = 75, - guidance_scale: float = 9.0, - noise_level: int = 20, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[ - Union[torch.Generator, List[torch.Generator]] - ] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[ - Callable[[int, int, torch.FloatTensor], None] - ] = None, - callback_steps: Optional[int] = 1, - ): - # 1. Check inputs - self.check_inputs(prompt, image, noise_level, callback_steps) - - # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - ) - - # 4. Preprocess image - image = preprocess(image) - image = image.to(dtype=text_embeddings.dtype, device=device) - - # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Add noise to image - noise_level = torch.tensor( - [noise_level], dtype=torch.long, device=device - ) - if device == "mps": - # randn does not work reproducibly on mps - noise = torch.randn( - image.shape, - generator=generator, - device="cpu", - dtype=text_embeddings.dtype, - ).to(device) - else: - noise = torch.randn( - image.shape, - generator=generator, - device=device, - dtype=text_embeddings.dtype, - ) - image = self.low_res_scheduler.add_noise(image, noise, noise_level) - - batch_multiplier = 2 if do_classifier_free_guidance else 1 - image = torch.cat([image] * batch_multiplier * num_images_per_prompt) - noise_level = torch.cat([noise_level] * image.shape[0]) - - # 6. 
Prepare latent variables - height, width = image.shape[2:] - # num_channels_latents = self.vae.config.latent_channels - num_channels_latents = 4 - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - text_embeddings.dtype, - device, - generator, - latents, - ) - - # 7. Check that sizes of image and latents match - num_channels_image = image.shape[1] - # if ( - # num_channels_latents + num_channels_image - # != self.unet.config.in_channels - # ): - # raise ValueError( - # f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - # f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - # f" `num_channels_image`: {num_channels_image} " - # f" = {num_channels_latents+num_channels_image}. Please verify the config of" - # " `pipeline.unet` or your `image` input." - # ) - - # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 9. Denoising loop - num_warmup_steps = ( - len(timesteps) - num_inference_steps * self.scheduler.order - ) - for i, t in tqdm(enumerate(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) - if do_classifier_free_guidance - else latents - ) - - # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) - latent_model_input = torch.cat([latent_model_input, image], dim=1) - - timestep = torch.tensor([t]).to(torch.float32) - - # predict the noise residual - noise_pred = shark_run_wrapper( - self.unet, - latent_model_input.half(), - timestep, - text_embeddings.half(), - noise_level, - ) - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs - ).prev_sample - - # # call the callback, if provided - # if i == len(timesteps) - 1 or ( - # (i + 1) > num_warmup_steps - # and (i + 1) % self.scheduler.order == 0 - # ): - # progress_bar.update() - # if callback is not None and i % callback_steps == 0: - # callback(i, t, latents) - - # 10. Post-processing - # make sure the VAE is in float32 mode, as it overflows in float16 - # self.vae.to(dtype=torch.float32) - image = self.decode_latents(latents.float()) - - # 11. 
Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/shark/examples/shark_inference/upscaler/upscaler_args.py b/shark/examples/shark_inference/upscaler/upscaler_args.py deleted file mode 100644 index f91c030658..0000000000 --- a/shark/examples/shark_inference/upscaler/upscaler_args.py +++ /dev/null @@ -1,98 +0,0 @@ -import argparse - -p = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter -) - -############################################################################## -### Stable Diffusion Params -############################################################################## - -p.add_argument( - "--prompts", - nargs="+", - default=["cyberpunk forest by Salvador Dali"], - help="text of which images to be generated.", -) - -p.add_argument( - "--negative-prompts", - nargs="+", - default=[""], - help="text you don't want to see in the generated image.", -) - -p.add_argument( - "--steps", - type=int, - default=50, - help="the no. of steps to do the sampling.", -) - -p.add_argument( - "--seed", - type=int, - default=42, - help="the seed to use.", -) - -p.add_argument( - "--guidance_scale", - type=float, - default=7.5, - help="the value to be used for guidance scaling.", -) - -############################################################################## -### Model Config and Usage Params -############################################################################## - -p.add_argument( - "--device", type=str, default="vulkan", help="device to run the model." -) - -p.add_argument( - "--precision", type=str, default="fp16", help="precision to run the model." -) - -p.add_argument( - "--import_mlir", - default=False, - action=argparse.BooleanOptionalAction, - help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.", -) - -p.add_argument( - "--load_vmfb", - default=True, - action=argparse.BooleanOptionalAction, - help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.", -) - -p.add_argument( - "--save_vmfb", - default=False, - action=argparse.BooleanOptionalAction, - help="saves the compiled flatbuffer to the local directory", -) - -############################################################################## -### IREE - Vulkan supported flags -############################################################################## - -p.add_argument( - "--iree-vulkan-target-triple", - type=str, - default="", - help="Specify target triple for vulkan", -) - -p.add_argument( - "--vulkan_debug_utils", - default=False, - action=argparse.BooleanOptionalAction, - help="Profiles vulkan device and collects the .rdc info", -) - - -args = p.parse_args() diff --git a/shark/examples/shark_inference/upscaler/utils.py b/shark/examples/shark_inference/upscaler/utils.py deleted file mode 100644 index 4531d4e816..0000000000 --- a/shark/examples/shark_inference/upscaler/utils.py +++ /dev/null @@ -1,230 +0,0 @@ -import os -import torch -from shark.shark_inference import SharkInference -from upscaler_args import args -from shark.shark_importer import import_with_fx -from shark.iree_utils.vulkan_utils import ( - set_iree_vulkan_runtime_flags, - get_vulkan_target_triple, - get_iree_vulkan_runtime_flags, -) - - -def _compile_module(shark_module, model_name, extra_args=[]): - if args.load_vmfb or args.save_vmfb: - device = ( - args.device - if "://" not in 
args.device - else "-".join(args.device.split("://")) - ) - extended_name = "{}_{}".format(model_name, device) - vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb") - if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb: - print(f"loading existing vmfb from: {vmfb_path}") - shark_module.load_module(vmfb_path, extra_args=extra_args) - else: - if args.save_vmfb: - print("Saving to {}".format(vmfb_path)) - else: - print( - "No vmfb found. Compiling and saving to {}".format( - vmfb_path - ) - ) - path = shark_module.save_module( - os.getcwd(), extended_name, extra_args - ) - shark_module.load_module(path, extra_args=extra_args) - else: - shark_module.compile(extra_args) - return shark_module - - -# Downloads the model from shark_tank and returns the shark_module. -def get_shark_model(tank_url, model_name, extra_args=[]): - from shark.shark_downloader import download_model - from shark.parser import shark_args - - # Set local shark_tank cache directory. - # shark_args.local_tank_cache = args.local_tank_cache - - mlir_model, func_name, inputs, golden_out = download_model( - model_name, - tank_url=tank_url, - frontend="torch", - ) - shark_module = SharkInference( - mlir_model, device=args.device, mlir_dialect="linalg" - ) - return _compile_module(shark_module, model_name, extra_args) - - -# Converts the torch-module into a shark_module. -def compile_through_fx( - model, inputs, model_name, is_f16=False, f16_input_mask=None, extra_args=[] -): - mlir_module, func_name = import_with_fx( - model, inputs, is_f16, f16_input_mask - ) - shark_module = SharkInference( - mlir_module, - device=args.device, - mlir_dialect="linalg", - ) - - return _compile_module(shark_module, model_name, extra_args) - - -def set_iree_runtime_flags(): - vulkan_runtime_flags = get_iree_vulkan_runtime_flags() - if args.enable_rgp: - vulkan_runtime_flags += [ - f"--enable_rgp=true", - f"--vulkan_debug_utils=true", - ] - set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) - - -def get_all_devices(driver_name): - """ - Inputs: driver_name - Returns a list of all the available devices for a given driver sorted by - the iree path names of the device as in --list_devices option in iree. - """ - from iree.runtime import get_driver - - driver = get_driver(driver_name) - device_list_src = driver.query_available_devices() - device_list_src.sort(key=lambda d: d["path"]) - return device_list_src - - -def get_device_mapping(driver, key_combination=3): - """This method ensures consistent device ordering when choosing - specific devices for execution - Args: - driver (str): execution driver (vulkan, cuda, rocm, etc) - key_combination (int, optional): choice for mapping value for device name. - 1 : path - 2 : name - 3 : (name, path) - Defaults to 3. - Returns: - dict: map to possible device names user can input mapped to desired combination of name/path. 
- """ - from shark.iree_utils._common import iree_device_map - - driver = iree_device_map(driver) - device_list = get_all_devices(driver) - device_map = dict() - - def get_output_value(dev_dict): - if key_combination == 1: - return f"{driver}://{dev_dict['path']}" - if key_combination == 2: - return dev_dict["name"] - if key_combination == 3: - return (dev_dict["name"], f"{driver}://{dev_dict['path']}") - - # mapping driver name to default device (driver://0) - device_map[f"{driver}"] = get_output_value(device_list[0]) - for i, device in enumerate(device_list): - # mapping with index - device_map[f"{driver}://{i}"] = get_output_value(device) - # mapping with full path - device_map[f"{driver}://{device['path']}"] = get_output_value(device) - return device_map - - -def map_device_to_name_path(device, key_combination=3): - """Gives the appropriate device data (supported name/path) for user selected execution device - Args: - device (str): user - key_combination (int, optional): choice for mapping value for device name. - 1 : path - 2 : name - 3 : (name, path) - Defaults to 3. - Raises: - ValueError: - Returns: - str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value - """ - driver = device.split("://")[0] - device_map = get_device_mapping(driver, key_combination) - try: - device_mapping = device_map[device] - except KeyError: - raise ValueError(f"Device '{device}' is not a valid device.") - return device_mapping - - -def set_init_device_flags(): - if "vulkan" in args.device: - # set runtime flags for vulkan. - set_iree_runtime_flags() - - # set triple flag to avoid multiple calls to get_vulkan_triple_flag - device_name, args.device = map_device_to_name_path(args.device) - if not args.iree_vulkan_target_triple: - triple = get_vulkan_target_triple(device_name) - if triple is not None: - args.iree_vulkan_target_triple = triple - print( - f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}." - ) - elif "cuda" in args.device: - args.device = "cuda" - elif "cpu" in args.device: - args.device = "cpu" - - # set max_length based on availability. - if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]: - args.max_length = 77 - elif args.variant == "openjourney": - args.max_length = 64 - - # use tuned models only in the case of stablediffusion/fp16 and rdna3 cards. - if ( - args.variant in ["openjourney", "dreamlike"] - or args.precision != "fp16" - or "vulkan" not in args.device - or "rdna3" not in args.iree_vulkan_target_triple - ): - args.use_tuned = False - print("Tuned models are currently not supported for this setting.") - - elif args.use_base_vae and args.variant != "stablediffusion": - args.use_tuned = False - print("Tuned models are currently not supported for this setting.") - - if args.use_tuned: - print("Using tuned models for stablediffusion/fp16 and rdna3 card.") - - -# Utility to get list of devices available. 
-def get_available_devices(): - def get_devices_by_name(driver_name): - from shark.iree_utils._common import iree_device_map - - device_list = [] - try: - driver_name = iree_device_map(driver_name) - device_list_dict = get_all_devices(driver_name) - print(f"{driver_name} devices are available.") - except: - print(f"{driver_name} devices are not available.") - else: - for i, device in enumerate(device_list_dict): - device_list.append(f"{driver_name}://{i} => {device['name']}") - return device_list - - set_iree_runtime_flags() - - available_devices = [] - vulkan_devices = get_devices_by_name("vulkan") - available_devices.extend(vulkan_devices) - cuda_devices = get_devices_by_name("cuda") - available_devices.extend(cuda_devices) - available_devices.append("cpu") - return available_devices diff --git a/shark/examples/shark_inference/v_diffusion.py b/shark/examples/shark_inference/v_diffusion.py deleted file mode 100644 index e74c8afc9b..0000000000 --- a/shark/examples/shark_inference/v_diffusion.py +++ /dev/null @@ -1,15 +0,0 @@ -from shark.shark_inference import SharkInference -from shark.shark_downloader import download_model - - -mlir_model, func_name, inputs, golden_out = download_model( - "v_diffusion", frontend="torch" -) - -shark_module = SharkInference( - mlir_model, device="vulkan", mlir_dialect="linalg" -) -shark_module.compile() -result = shark_module.forward(inputs) -print("The obtained result via shark is: ", result) -print("The golden result is:", golden_out) diff --git a/shark/examples/shark_training/bert_training.py b/shark/examples/shark_training/bert_training.py deleted file mode 100644 index b0cf43a281..0000000000 --- a/shark/examples/shark_training/bert_training.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -from torch.nn.utils import stateless -from transformers import AutoTokenizer, AutoModelForSequenceClassification -from shark.shark_trainer import SharkTrainer - - -class MiniLMSequenceClassification(torch.nn.Module): - def __init__(self): - super().__init__() - self.model = AutoModelForSequenceClassification.from_pretrained( - "microsoft/MiniLM-L12-H384-uncased", # The pretrained model. - num_labels=2, # The number of output labels--2 for binary classification. - output_attentions=False, # Whether the model returns attentions weights. - output_hidden_states=False, # Whether the model returns all hidden-states. 
- torchscript=True, - ) - - def forward(self, tokens): - return self.model.forward(tokens)[0] - - -mod = MiniLMSequenceClassification() - - -def get_sorted_params(named_params): - return [i[1] for i in sorted(named_params.items())] - - -print(dict(mod.named_buffers())) - -inp = (torch.randint(2, (1, 128)),) - - -def forward(params, buffers, args): - params_and_buffers = {**params, **buffers} - stateless.functional_call( - mod, params_and_buffers, args, {} - ).sum().backward() - optim = torch.optim.SGD(get_sorted_params(params), lr=0.01) - # optim.load_state_dict(optim_state) - optim.step() - return params, buffers - - -shark_module = SharkTrainer(mod, inp) -shark_module.compile(forward) -shark_module.train(num_iters=2) -print("training done") diff --git a/shark/examples/shark_training/bert_training_load_tf.py b/shark/examples/shark_training/bert_training_load_tf.py deleted file mode 100644 index ae214c4dd4..0000000000 --- a/shark/examples/shark_training/bert_training_load_tf.py +++ /dev/null @@ -1,60 +0,0 @@ -import numpy as np -import os -import time -import tensorflow as tf - -from shark.shark_trainer import SharkTrainer -from shark.parser import parser -from urllib import request - -parser.add_argument( - "--download_mlir_path", - type=str, - default="bert_tf_training.mlir", - help="Specifies path to target mlir file that will be loaded.", -) -load_args, unknown = parser.parse_known_args() - -tf.random.set_seed(0) -vocab_size = 100 -NUM_CLASSES = 5 -SEQUENCE_LENGTH = 512 -BATCH_SIZE = 1 - -# Download BERT model from tank and train. -if __name__ == "__main__": - predict_sample_input = [ - np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)), - np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)), - np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)), - ] - file_link = "https://storage.googleapis.com/shark_tank/users/stanley/bert_tf_training.mlir" - response = request.urlretrieve(file_link, load_args.download_mlir_path) - sample_input_tensors = [ - tf.convert_to_tensor(val, dtype=tf.int32) - for val in predict_sample_input - ] - num_iter = 10 - if not os.path.isfile(load_args.download_mlir_path): - raise ValueError( - f"Tried looking for target mlir in {load_args.download_mlir_path}, but cannot be found." 
- ) - with open(load_args.download_mlir_path, "rb") as input_file: - bert_mlir = input_file.read() - shark_module = SharkTrainer( - bert_mlir, - ( - sample_input_tensors, - tf.convert_to_tensor( - np.random.randint(5, size=(BATCH_SIZE)), dtype=tf.int32 - ), - ), - ) - shark_module.set_frontend("mhlo") - shark_module.compile() - start = time.time() - print(shark_module.train(num_iter)) - end = time.time() - total_time = end - start - print("time: " + str(total_time)) - print("time/iter: " + str(total_time / num_iter)) diff --git a/shark/examples/shark_training/bert_training_tf.py b/shark/examples/shark_training/bert_training_tf.py deleted file mode 100644 index 8db49c61b0..0000000000 --- a/shark/examples/shark_training/bert_training_tf.py +++ /dev/null @@ -1,98 +0,0 @@ -from absl import app -import time - -import numpy as np -import tensorflow as tf - -from official.nlp.modeling import layers -from official.nlp.modeling import networks -from official.nlp.modeling.models import bert_classifier - -from shark.shark_trainer import SharkTrainer - - -tf.random.set_seed(0) -vocab_size = 100 -NUM_CLASSES = 5 -SEQUENCE_LENGTH = 512 -BATCH_SIZE = 1 -# Create a set of 2-dimensional inputs -bert_input = [ - tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32), - tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32), - tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32), -] - - -class BertModule(tf.Module): - def __init__(self): - super(BertModule, self).__init__() - dict_outputs = False - test_network = networks.BertEncoder( - vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs - ) - - # Create a BERT trainer with the created network. - bert_trainer_model = bert_classifier.BertClassifier( - test_network, num_classes=NUM_CLASSES - ) - bert_trainer_model.summary() - - # Invoke the trainer model on the inputs. This causes the layer to be built. - self.m = bert_trainer_model - self.m.predict = lambda x: self.m.call(x, training=False) - self.predict = tf.function(input_signature=[bert_input])( - self.m.predict - ) - self.m.learn = lambda x, y: self.m.call(x, training=False) - self.loss = tf.keras.losses.SparseCategoricalCrossentropy() - self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2) - - @tf.function( - input_signature=[ - bert_input, # inputs - tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels - ], - jit_compile=True, - ) - def forward(self, inputs, labels): - with tf.GradientTape() as tape: - # Capture the gradients from forward prop... - probs = self.m(inputs, training=True) - loss = self.loss(labels, probs) - - # ...and use them to update the model's weights. 
- variables = self.m.trainable_variables - gradients = tape.gradient(loss, variables) - self.optimizer.apply_gradients(zip(gradients, variables)) - return loss - - -if __name__ == "__main__": - predict_sample_input = [ - np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)), - np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)), - np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)), - ] - sample_input_tensors = [ - tf.convert_to_tensor(val, dtype=tf.int32) - for val in predict_sample_input - ] - num_iter = 10 - shark_module = SharkTrainer( - BertModule(), - ( - sample_input_tensors, - tf.convert_to_tensor( - np.random.randint(5, size=(BATCH_SIZE)), dtype=tf.int32 - ), - ), - ) - shark_module.set_frontend("tensorflow") - shark_module.compile() - start = time.time() - print(shark_module.train(num_iter)) - end = time.time() - total_time = end - start - print("time: " + str(total_time)) - print("time/iter: " + str(total_time / num_iter)) diff --git a/shark/examples/shark_training/neural_net_training.py b/shark/examples/shark_training/neural_net_training.py deleted file mode 100644 index a28397e0c0..0000000000 --- a/shark/examples/shark_training/neural_net_training.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -from torch.nn.utils import _stateless -from shark.shark_trainer import SharkTrainer - - -class Foo(torch.nn.Module): - def __init__(self): - super(Foo, self).__init__() - self.l1 = torch.nn.Linear(10, 16) - self.relu = torch.nn.ReLU() - self.l2 = torch.nn.Linear(16, 2) - - def forward(self, x): - out = self.l1(x) - out = self.relu(out) - out = self.l2(out) - return out - - -mod = Foo() -inp = (torch.randn(10, 10),) - - -def get_sorted_params(named_params): - return [i[1] for i in sorted(named_params.items())] - - -def forward(params, buffers, args): - params_and_buffers = {**params, **buffers} - _stateless.functional_call( - mod, params_and_buffers, args, {} - ).sum().backward() - optim = torch.optim.SGD(get_sorted_params(params), lr=0.01) - optim.step() - return params, buffers - - -# fx_graph = forward(dict(mod.named_parameters()), dict(mod.named_buffers()), inp) - -shark_module = SharkTrainer(mod, inp) -# Pass the training function in case of torch -shark_module.compile(training_fn=forward) - -shark_module.train(num_iters=10) diff --git a/shark/examples/shark_training/stable-diffusion-img2img/README.md b/shark/examples/shark_training/stable-diffusion-img2img/README.md deleted file mode 100644 index 8ac562db1f..0000000000 --- a/shark/examples/shark_training/stable-diffusion-img2img/README.md +++ /dev/null @@ -1,41 +0,0 @@ -# Stable Diffusion Img2Img model - -## Installation - -
- Installation (Linux) - -### Activate shark.venv Virtual Environment - -```shell -source shark.venv/bin/activate - -# Some older pip installs may not be able to handle the recent PyTorch deps -python -m pip install --upgrade pip -``` - -### Install dependencies - -# Run the setup.sh script - -```shell -./setup.sh -``` - -### Run the Stable diffusion Img2Img model - -To run the model with the default set of images and params, run: -```shell -python stable_diffusion_img2img.py -``` -To run the model with your set of images, and parameters you need to specify the following params: -1.) Input images directory with the arg `--input_dir` containing 3-5 images. -2.) What to teach the model? Using the arg `--what_to_teach`, allowed values are `object` or `style`. -3.) Placeholder token using the arg `--placeholder_token`, that represents your new concept. It should be passed with the opening and closing angle brackets. For ex: token is `cat-toy`, it should be passed as ``. -4.) Initializer token using the arg `--initializer_token`, which summarise what is your new concept. - -For the result, you need to pass the text prompt with the arg: `--prompt`. The prompt string should contain a "*s" in it, which will be replaced by the placeholder token during the inference. - -By default the result images will go into the `sd_result` dir. To specify your output dir use the arg: `--output_dir`. - -The default value of max_training_steps is `3000`, which takes some hours to complete. You can pass the smaller value with the arg `--training_steps`. Specify the number of images to be sampled for the result with the `--num_inference_samples` arg. diff --git a/shark/examples/shark_training/stable-diffusion-img2img/setup.sh b/shark/examples/shark_training/stable-diffusion-img2img/setup.sh deleted file mode 100644 index 6d62d04d55..0000000000 --- a/shark/examples/shark_training/stable-diffusion-img2img/setup.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -TD="$(cd $(dirname $0) && pwd)" -if [ -z "$PYTHON" ]; then - PYTHON="$(which python3)" -fi - -function die() { - echo "Error executing command: $*" - exit 1 -} - -PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; print("{0}.{1}".format(*version))'` - -echo "Python: $PYTHON" -echo "Python version: $PYTHON_VERSION_X_Y" - -mkdir input_images - -wget https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg -P input_images/ -wget https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg -P input_images/ -wget https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg -P input_images/ -wget https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg -P input_images/ - -pip install diffusers["training"]==0.4.1 transformers ftfy opencv-python diff --git a/shark/examples/shark_training/stable-diffusion-img2img/stable_diffusion_img2img.py b/shark/examples/shark_training/stable-diffusion-img2img/stable_diffusion_img2img.py deleted file mode 100644 index 00ba8ffbfd..0000000000 --- a/shark/examples/shark_training/stable-diffusion-img2img/stable_diffusion_img2img.py +++ /dev/null @@ -1,600 +0,0 @@ -# Textual-inversion fine-tuning for Stable Diffusion using diffusers -# This script shows how to "teach" Stable Diffusion a new concept via -# textual-inversion using 🤗 Hugging Face [🧨 Diffusers library](https://github.com/huggingface/diffusers). -# By using just 3-5 images you can teach new concepts to Stable Diffusion -# and personalize the model on your own images. 
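The script that follows teaches the model a new concept via textual inversion. Its core tokenizer and text-encoder bookkeeping can be sketched on its own as below; this is a minimal sketch, assuming the same `CompVis/stable-diffusion-v1-4` checkpoint the script uses, and `<my-concept>` / `toy` are illustrative placeholders rather than values from the original example:

```python
# Minimal sketch of the textual-inversion token setup, shown standalone.
# "<my-concept>" and "toy" are illustrative placeholders.
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "CompVis/stable-diffusion-v1-4"
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# Register the new concept token and grow the embedding table to match.
assert tokenizer.add_tokens("<my-concept>") == 1, "token already exists"
text_encoder.resize_token_embeddings(len(tokenizer))

# Seed the new embedding from a word that roughly describes the concept.
init_ids = tokenizer.encode("toy", add_special_tokens=False)
assert len(init_ids) == 1, "initializer must be a single token"
new_id = tokenizer.convert_tokens_to_ids("<my-concept>")
embeds = text_encoder.get_input_embeddings().weight.data
embeds[new_id] = embeds[init_ids[0]]
```

Only this newly added embedding row is trained; everything else in the pipeline is kept frozen, which is what makes 3-5 images enough.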
- -import argparse -import itertools -import math -import os -import random -import cv2 - -import numpy as np -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch.utils.data import Dataset - -import PIL -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import set_seed -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - PNDMScheduler, - StableDiffusionPipeline, - UNet2DConditionModel, -) -from diffusers.hub_utils import init_git_repo, push_to_hub -from diffusers.optimization import get_scheduler -from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker -from PIL import Image -from torchvision import transforms -from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -YOUR_TOKEN = "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" - -p = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter -) -p.add_argument( - "--input_dir", - type=str, - default="input_images/", - help="the directory contains the images used for fine tuning", -) -p.add_argument( - "--output_dir", - type=str, - default="sd_result", - help="the directory contains the images used for fine tuning", -) -p.add_argument( - "--training_steps", - type=int, - default=3000, - help="the maximum number of training steps", -) -p.add_argument("--seed", type=int, default=42, help="the random seed") -p.add_argument( - "--what_to_teach", - type=str, - choices=["object", "style"], - default="object", - help="what is it that you are teaching?", -) -p.add_argument( - "--placeholder_token", - type=str, - default="", - help="It is the token you are going to use to represent your new concept", -) -p.add_argument( - "--initializer_token", - type=str, - default="toy", - help="It is a word that can summarise what is your new concept", -) -p.add_argument( - "--inference_steps", - type=int, - default=50, - help="the number of steps for inference", -) -p.add_argument( - "--num_inference_samples", - type=int, - default=4, - help="the number of samples for inference", -) -p.add_argument( - "--prompt", - type=str, - default="a grafitti in a wall with a *s on it", - help="the text prompt to use", -) -args = p.parse_args() - -if "*s" not in args.prompt: - raise ValueError( - f'The prompt should have a "*s" which will be replaced by a placeholder token.' - ) - -prompt1, prompt2 = args.prompt.split("*s") -args.prompt = prompt1 + args.placeholder_token + prompt2 - -pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4" - -# Load input images. 
-images = [] -for filename in os.listdir(args.input_dir): - img = cv2.imread(os.path.join(args.input_dir, filename)) - if img is not None: - images.append(img) - -# Setup the prompt templates for training -imagenet_templates_small = [ - "a photo of a {}", - "a rendering of a {}", - "a cropped photo of the {}", - "the photo of a {}", - "a photo of a clean {}", - "a photo of a dirty {}", - "a dark photo of the {}", - "a photo of my {}", - "a photo of the cool {}", - "a close-up photo of a {}", - "a bright photo of the {}", - "a cropped photo of a {}", - "a photo of the {}", - "a good photo of the {}", - "a photo of one {}", - "a close-up photo of the {}", - "a rendition of the {}", - "a photo of the clean {}", - "a rendition of a {}", - "a photo of a nice {}", - "a good photo of a {}", - "a photo of the nice {}", - "a photo of the small {}", - "a photo of the weird {}", - "a photo of the large {}", - "a photo of a cool {}", - "a photo of a small {}", -] - -imagenet_style_templates_small = [ - "a painting in the style of {}", - "a rendering in the style of {}", - "a cropped painting in the style of {}", - "the painting in the style of {}", - "a clean painting in the style of {}", - "a dirty painting in the style of {}", - "a dark painting in the style of {}", - "a picture in the style of {}", - "a cool painting in the style of {}", - "a close-up painting in the style of {}", - "a bright painting in the style of {}", - "a cropped painting in the style of {}", - "a good painting in the style of {}", - "a close-up painting in the style of {}", - "a rendition in the style of {}", - "a nice painting in the style of {}", - "a small painting in the style of {}", - "a weird painting in the style of {}", - "a large painting in the style of {}", -] - - -# Setup the dataset -class TextualInversionDataset(Dataset): - def __init__( - self, - data_root, - tokenizer, - learnable_property="object", # [object, style] - size=512, - repeats=100, - interpolation="bicubic", - flip_p=0.5, - set="train", - placeholder_token="*", - center_crop=False, - ): - self.data_root = data_root - self.tokenizer = tokenizer - self.learnable_property = learnable_property - self.size = size - self.placeholder_token = placeholder_token - self.center_crop = center_crop - self.flip_p = flip_p - - self.image_paths = [ - os.path.join(self.data_root, file_path) - for file_path in os.listdir(self.data_root) - ] - - self.num_images = len(self.image_paths) - self._length = self.num_images - - if set == "train": - self._length = self.num_images * repeats - - self.interpolation = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] - - self.templates = ( - imagenet_style_templates_small - if learnable_property == "style" - else imagenet_templates_small - ) - self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) - - def __len__(self): - return self._length - - def __getitem__(self, i): - example = {} - image = Image.open(self.image_paths[i % self.num_images]) - - if not image.mode == "RGB": - image = image.convert("RGB") - - placeholder_string = self.placeholder_token - text = random.choice(self.templates).format(placeholder_string) - - example["input_ids"] = self.tokenizer( - text, - padding="max_length", - truncation=True, - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids[0] - - # default to score-sde preprocessing - img = np.array(image).astype(np.uint8) - - if self.center_crop: - crop = 
min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( - img.shape[0], - img.shape[1], - ) - img = img[ - (h - crop) // 2 : (h + crop) // 2, - (w - crop) // 2 : (w + crop) // 2, - ] - - image = Image.fromarray(img) - image = image.resize( - (self.size, self.size), resample=self.interpolation - ) - - image = self.flip_transform(image) - image = np.array(image).astype(np.uint8) - image = (image / 127.5 - 1.0).astype(np.float32) - - example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) - return example - - -# Setting up the model -# Load the tokenizer and add the placeholder token as a additional special token. -# Please read and if you agree accept the LICENSE -# [here](https://huggingface.co/CompVis/stable-diffusion-v1-4) if you see an error -tokenizer = CLIPTokenizer.from_pretrained( - pretrained_model_name_or_path, - subfolder="tokenizer", - use_auth_token=YOUR_TOKEN, -) - -# Add the placeholder token in tokenizer -num_added_tokens = tokenizer.add_tokens(args.placeholder_token) -if num_added_tokens == 0: - raise ValueError( - f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" - " `placeholder_token` that is not already in the tokenizer." - ) - -# Get token ids for our placeholder and initializer token. -# This code block will complain if initializer string is not a single token -# Convert the initializer_token, placeholder_token to ids -token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) -# Check if initializer_token is a single token or a sequence of tokens -if len(token_ids) > 1: - raise ValueError("The initializer token must be a single token.") - -initializer_token_id = token_ids[0] -placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) - -# Load the Stable Diffusion model -# Load models and create wrapper for stable diffusion -text_encoder = CLIPTextModel.from_pretrained( - pretrained_model_name_or_path, - subfolder="text_encoder", - use_auth_token=YOUR_TOKEN, -) -vae = AutoencoderKL.from_pretrained( - pretrained_model_name_or_path, - subfolder="vae", - use_auth_token=YOUR_TOKEN, -) -unet = UNet2DConditionModel.from_pretrained( - pretrained_model_name_or_path, - subfolder="unet", - use_auth_token=YOUR_TOKEN, -) - -# We have added the `placeholder_token` in the `tokenizer` so we resize the token embeddings here, -# this will a new embedding vector in the token embeddings for our `placeholder_token` -text_encoder.resize_token_embeddings(len(tokenizer)) - -# Initialise the newly added placeholder token with the embeddings of the initializer token -token_embeds = text_encoder.get_input_embeddings().weight.data -token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] - -# In Textual-Inversion we only train the newly added embedding vector, -# so lets freeze rest of the model parameters here. 
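The `requires_grad` freezing below covers the bulk of the model, but the text encoder's embedding matrix itself stays trainable, so the training loop later zeroes the gradient of every row except the placeholder token's. A self-contained toy version of that masking trick, with made-up sizes and ids:

```python
import torch

# Toy illustration of the gradient-masking trick used further below: the
# optimizer is handed the whole embedding matrix, but only the placeholder
# row actually moves because every other row's gradient is zeroed first.
emb = torch.nn.Embedding(10, 4)
placeholder_id = 9
# weight_decay=0.0 so rows with zeroed gradients are truly left untouched
opt = torch.optim.AdamW(emb.parameters(), lr=5e-4, weight_decay=0.0)

loss = emb(torch.tensor([1, 3, placeholder_id])).pow(2).mean()
loss.backward()

mask = torch.arange(emb.num_embeddings) != placeholder_id
emb.weight.grad[mask] = 0.0  # freeze every embedding row except the new token
opt.step()
```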
- - -def freeze_params(params): - for param in params: - param.requires_grad = False - - -# Freeze vae and unet -freeze_params(vae.parameters()) -freeze_params(unet.parameters()) -# Freeze all parameters except for the token embeddings in text encoder -params_to_freeze = itertools.chain( - text_encoder.text_model.encoder.parameters(), - text_encoder.text_model.final_layer_norm.parameters(), - text_encoder.text_model.embeddings.position_embedding.parameters(), -) -freeze_params(params_to_freeze) - -# Creating our training data - -train_dataset = TextualInversionDataset( - data_root=args.input_dir, - tokenizer=tokenizer, - size=512, - placeholder_token=args.placeholder_token, - repeats=100, - learnable_property=args.what_to_teach, # Option selected above between object and style - center_crop=False, - set="train", -) - - -def create_dataloader(train_batch_size=1): - return torch.utils.data.DataLoader( - train_dataset, batch_size=train_batch_size, shuffle=True - ) - - -# Create noise_scheduler for training. -noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - tensor_format="pt", -) - -# Define hyperparameters for our training -hyperparameters = { - "learning_rate": 5e-04, - "scale_lr": True, - "max_train_steps": args.training_steps, - "train_batch_size": 1, - "gradient_accumulation_steps": 4, - "seed": args.seed, - "output_dir": "sd-concept-output", -} - - -def training_function(text_encoder, vae, unet): - logger = get_logger(__name__) - - train_batch_size = hyperparameters["train_batch_size"] - gradient_accumulation_steps = hyperparameters[ - "gradient_accumulation_steps" - ] - learning_rate = hyperparameters["learning_rate"] - max_train_steps = hyperparameters["max_train_steps"] - output_dir = hyperparameters["output_dir"] - - accelerator = Accelerator( - gradient_accumulation_steps=gradient_accumulation_steps, - ) - - train_dataloader = create_dataloader(train_batch_size) - - if hyperparameters["scale_lr"]: - learning_rate = ( - learning_rate - * gradient_accumulation_steps - * train_batch_size - * accelerator.num_processes - ) - - # Initialize the optimizer - optimizer = torch.optim.AdamW( - text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings - lr=learning_rate, - ) - - text_encoder, optimizer, train_dataloader = accelerator.prepare( - text_encoder, optimizer, train_dataloader - ) - - # Move vae and unet to device - vae.to(accelerator.device) - unet.to(accelerator.device) - - # Keep vae and unet in eval model as we don't train these - vae.eval() - unet.eval() - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / gradient_accumulation_steps - ) - num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) - - # Train! - total_batch_size = ( - train_batch_size - * accelerator.num_processes - * gradient_accumulation_steps - ) - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Instantaneous batch size per device = {train_batch_size}") - logger.info( - f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" - ) - logger.info( - f" Gradient Accumulation steps = {gradient_accumulation_steps}" - ) - logger.info(f" Total optimization steps = {max_train_steps}") - # Only show the progress bar once on each machine. 
- progress_bar = tqdm( - range(max_train_steps), disable=not accelerator.is_local_main_process - ) - progress_bar.set_description("Steps") - global_step = 0 - - for epoch in range(num_train_epochs): - text_encoder.train() - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(text_encoder): - # Convert images to latent space - latents = ( - vae.encode(batch["pixel_values"]) - .latent_dist.sample() - .detach() - ) - latents = latents * 0.18215 - - # Sample noise that we'll add to the latents - noise = torch.randn(latents.shape).to(latents.device) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - noise_scheduler.num_train_timesteps, - (bsz,), - device=latents.device, - ).long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise( - latents, noise, timesteps - ) - - # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] - - # Predict the noise residual - noise_pred = unet( - noisy_latents, timesteps, encoder_hidden_states - ).sample - - loss = ( - F.mse_loss(noise_pred, noise, reduction="none") - .mean([1, 2, 3]) - .mean() - ) - accelerator.backward(loss) - - # Zero out the gradients for all token embeddings except the newly added - # embeddings for the concept, as we only want to optimize the concept embeddings - if accelerator.num_processes > 1: - grads = ( - text_encoder.module.get_input_embeddings().weight.grad - ) - else: - grads = text_encoder.get_input_embeddings().weight.grad - # Get the index for tokens that we want to zero the grads for - index_grads_to_zero = ( - torch.arange(len(tokenizer)) != placeholder_token_id - ) - grads.data[index_grads_to_zero, :] = grads.data[ - index_grads_to_zero, : - ].fill_(0) - - optimizer.step() - optimizer.zero_grad() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - logs = {"loss": loss.detach().item()} - progress_bar.set_postfix(**logs) - - if global_step >= max_train_steps: - break - - accelerator.wait_for_everyone() - - # Create the pipeline using using the trained modules and save it. 
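The block below assembles the full pipeline, saves it to `sd-concept-output`, and writes the learned embedding to `learned_embeds.bin`. For completeness, here is a hedged sketch (not part of the original example) of how such a saved concept can later be carried over to a fresh copy of the base model; the paths follow the defaults used in this script:

```python
import torch
from diffusers import StableDiffusionPipeline

# Sketch: load the concept saved as learned_embeds.bin into a fresh copy of
# the base model so the placeholder token can be used in prompts again.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
learned = torch.load("sd-concept-output/learned_embeds.bin")

for token, embedding in learned.items():
    pipe.tokenizer.add_tokens(token)
    pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
    token_id = pipe.tokenizer.convert_tokens_to_ids(token)
    pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
```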
- if accelerator.is_main_process: - pipeline = StableDiffusionPipeline( - text_encoder=accelerator.unwrap_model(text_encoder), - vae=vae, - unet=unet, - tokenizer=tokenizer, - scheduler=PNDMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - skip_prk_steps=True, - ), - safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker" - ), - feature_extractor=CLIPFeatureExtractor.from_pretrained( - "openai/clip-vit-base-patch32" - ), - ) - pipeline.save_pretrained(output_dir) - # Also save the newly trained embeddings - learned_embeds = ( - accelerator.unwrap_model(text_encoder) - .get_input_embeddings() - .weight[placeholder_token_id] - ) - learned_embeds_dict = { - args.placeholder_token: learned_embeds.detach().cpu() - } - torch.save( - learned_embeds_dict, os.path.join(output_dir, "learned_embeds.bin") - ) - - -import accelerate - -accelerate.notebook_launcher( - training_function, args=(text_encoder, vae, unet), num_processes=1 -) - -# Set up the pipeline -pipe = StableDiffusionPipeline.from_pretrained( - hyperparameters["output_dir"], - # torch_dtype=torch.float16, -) - -all_images = [] -for _ in range(args.num_inference_samples): - images = pipe( - [args.prompt], - num_inference_steps=args.inference_steps, - guidance_scale=7.5, - ).images - all_images.extend(images) - -# output_path = os.path.abspath(os.path.join(os.getcwd(), args.output_dir)) -if not os.path.isdir(args.output_dir): - os.mkdir(args.output_dir) - -[ - image.save(f"{args.output_dir}/{i}.jpeg") - for i, image in enumerate(all_images) -] diff --git a/shark/examples/shark_training/stable_diffusion/README.md b/shark/examples/shark_training/stable_diffusion/README.md deleted file mode 100644 index 3d6b848d6e..0000000000 --- a/shark/examples/shark_training/stable_diffusion/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# Stable Diffusion Fine Tuning - -## Installation (Linux) - -### Activate shark.venv Virtual Environment - -```shell -source shark.venv/bin/activate - -# Some older pip installs may not be able to handle the recent PyTorch deps -python -m pip install --upgrade pip -``` - -## Install dependencies - -### Run the following installation commands: -``` -pip install -U git+https://github.com/huggingface/diffusers.git -pip install accelerate transformers ftfy -``` - -### Build torch-mlir with the following branch: - -Please cherry-pick this branch of torch-mlir: https://github.com/vivekkhandelwal1/torch-mlir/tree/sd-ops -and build it locally. You can find the instructions for using locally build Torch-MLIR, -here: https://github.com/nod-ai/SHARK#how-to-use-your-locally-built-iree--torch-mlir-with-shark - -## Run the Stable diffusion fine tuning - -To run the model with the default set of images and params, run: -```shell -python stable_diffusion_fine_tuning.py -``` -By default the training is run through the PyTorch path. If you want to train the model using the Torchdynamo path of Torch-MLIR, you need to specify `--use_torchdynamo=True`. - -The default number of training steps are `2000`, which would take many hours to complete based on your system config. You can pass the smaller value with the arg `--training_steps`. You can specify the number of images to be sampled for the result with the `--num_inference_samples` arg. For the number of inference steps you can use `--inference_steps` flag. 
- -For example, you can run the training for a limited set of steps via the dynamo path by using the following command: -``` -python stable_diffusion_fine_tuning.py --training_steps=1 --inference_steps=1 --num_inference_samples=1 --train_batch_size=1 --use_torchdynamo=True -``` - -You can also specify the device to be used via the flag `--device`. The default value is `cpu`, for GPU execution you can specify `--device="cuda"`. diff --git a/shark/examples/shark_training/stable_diffusion/stable_diffusion_fine_tuning.py b/shark/examples/shark_training/stable_diffusion/stable_diffusion_fine_tuning.py deleted file mode 100644 index e16a25dd8b..0000000000 --- a/shark/examples/shark_training/stable_diffusion/stable_diffusion_fine_tuning.py +++ /dev/null @@ -1,914 +0,0 @@ -# Install the required libs -# pip install -U git+https://github.com/huggingface/diffusers.git -# pip install accelerate transformers ftfy - -# Import required libraries -import argparse -import itertools -import math -import os -from typing import List -import random - -import numpy as np -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch.utils.data import Dataset - -import PIL -import logging - -import torch_mlir -from torch_mlir.dynamo import make_simple_dynamo_backend -import torch._dynamo as dynamo -from torch.fx.experimental.proxy_tensor import make_fx -from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend -from shark.shark_inference import SharkInference - -torch._dynamo.config.verbose = True - -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - PNDMScheduler, - StableDiffusionPipeline, - UNet2DConditionModel, -) -from diffusers.optimization import get_scheduler -from diffusers.pipelines.stable_diffusion import ( - StableDiffusionSafetyChecker, -) -from PIL import Image -from torchvision import transforms -from tqdm.auto import tqdm -from transformers import ( - CLIPFeatureExtractor, - CLIPTextModel, - CLIPTokenizer, -) - - -# Enter your HuggingFace Token -# Note: You can comment this prompt and just set your token instead of passing it through cli for every execution. -hf_token = input("Please enter your huggingface token here: ") -YOUR_TOKEN = hf_token - - -def image_grid(imgs, rows, cols): - assert len(imgs) == rows * cols - - w, h = imgs[0].size - grid = Image.new("RGB", size=(cols * w, rows * h)) - grid_w, grid_h = grid.size - - for i, img in enumerate(imgs): - grid.paste(img, box=(i % cols * w, i // cols * h)) - return grid - - -# `pretrained_model_name_or_path` which Stable Diffusion checkpoint you want to use -# Options: 1.) "stabilityai/stable-diffusion-2" -# 2.) "stabilityai/stable-diffusion-2-base" -# 3.) "CompVis/stable-diffusion-v1-4" -# 4.) "runwayml/stable-diffusion-v1-5" -pretrained_model_name_or_path = "stabilityai/stable-diffusion-2" - -# Add here the URLs to the images of the concept you are adding. 
3-5 should be fine -urls = [ - "https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg", - "https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg", - "https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg", - "https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg", - ## You can add additional images here -] - -# Downloading Images -import requests -import glob -from io import BytesIO - - -def download_image(url): - try: - response = requests.get(url) - except: - return None - return Image.open(BytesIO(response.content)).convert("RGB") - - -images = list(filter(None, [download_image(url) for url in urls])) -save_path = "./my_concept" -if not os.path.exists(save_path): - os.mkdir(save_path) -[image.save(f"{save_path}/{i}.jpeg") for i, image in enumerate(images)] - -p = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -) -p.add_argument( - "--input_dir", - type=str, - default="my_concept/", - help="the directory contains the images used for fine tuning", -) -p.add_argument( - "--output_dir", - type=str, - default="sd_result", - help="the directory contains the images used for fine tuning", -) -p.add_argument( - "--training_steps", - type=int, - default=2000, - help="the maximum number of training steps", -) -p.add_argument( - "--train_batch_size", - type=int, - default=4, - help="The batch size for training", -) -p.add_argument( - "--save_steps", - type=int, - default=250, - help="the number of steps after which to save the learned concept", -) -p.add_argument("--seed", type=int, default=42, help="the random seed") -p.add_argument( - "--what_to_teach", - type=str, - choices=["object", "style"], - default="object", - help="what is it that you are teaching?", -) -p.add_argument( - "--placeholder_token", - type=str, - default="", - help="It is the token you are going to use to represent your new concept", -) -p.add_argument( - "--initializer_token", - type=str, - default="toy", - help="It is a word that can summarise what is your new concept", -) -p.add_argument( - "--inference_steps", - type=int, - default=50, - help="the number of steps for inference", -) -p.add_argument( - "--num_inference_samples", - type=int, - default=4, - help="the number of samples for inference", -) -p.add_argument( - "--prompt", - type=str, - default="a grafitti in a wall with a *s on it", - help="the text prompt to use", -) -p.add_argument( - "--device", - type=str, - default="cpu", - help="The device to use", -) -p.add_argument( - "--use_torchdynamo", - type=bool, - default=False, - help="This flag is used to determine whether the training has to be done through the torchdynamo path or not.", -) -args = p.parse_args() -torch.manual_seed(args.seed) - -if "*s" not in args.prompt: - raise ValueError( - f'The prompt should have a "*s" which will be replaced by a placeholder token.' - ) - -prompt1, prompt2 = args.prompt.split("*s") -args.prompt = prompt1 + args.placeholder_token + prompt2 - -# `images_path` is a path to directory containing the training images. 
-images_path = args.input_dir -while not os.path.exists(str(images_path)): - print( - "The images_path specified does not exist, use the colab file explorer to copy the path :" - ) - images_path = input("") -save_path = images_path - -# Setup and check the images you have just added -images = [] -for file_path in os.listdir(save_path): - try: - image_path = os.path.join(save_path, file_path) - images.append(Image.open(image_path).resize((512, 512))) - except: - print( - f"{image_path} is not a valid image, please make sure to remove this file from the directory otherwise the training could fail." - ) -image_grid(images, 1, len(images)) - -########### Create Dataset ########## - -# Setup the prompt templates for training -imagenet_templates_small = [ - "a photo of a {}", - "a rendering of a {}", - "a cropped photo of the {}", - "the photo of a {}", - "a photo of a clean {}", - "a photo of a dirty {}", - "a dark photo of the {}", - "a photo of my {}", - "a photo of the cool {}", - "a close-up photo of a {}", - "a bright photo of the {}", - "a cropped photo of a {}", - "a photo of the {}", - "a good photo of the {}", - "a photo of one {}", - "a close-up photo of the {}", - "a rendition of the {}", - "a photo of the clean {}", - "a rendition of a {}", - "a photo of a nice {}", - "a good photo of a {}", - "a photo of the nice {}", - "a photo of the small {}", - "a photo of the weird {}", - "a photo of the large {}", - "a photo of a cool {}", - "a photo of a small {}", -] - -imagenet_style_templates_small = [ - "a painting in the style of {}", - "a rendering in the style of {}", - "a cropped painting in the style of {}", - "the painting in the style of {}", - "a clean painting in the style of {}", - "a dirty painting in the style of {}", - "a dark painting in the style of {}", - "a picture in the style of {}", - "a cool painting in the style of {}", - "a close-up painting in the style of {}", - "a bright painting in the style of {}", - "a cropped painting in the style of {}", - "a good painting in the style of {}", - "a close-up painting in the style of {}", - "a rendition in the style of {}", - "a nice painting in the style of {}", - "a small painting in the style of {}", - "a weird painting in the style of {}", - "a large painting in the style of {}", -] - - -# Setup the dataset -class TextualInversionDataset(Dataset): - def __init__( - self, - data_root, - tokenizer, - learnable_property="object", # [object, style] - size=512, - repeats=100, - interpolation="bicubic", - flip_p=0.5, - set="train", - placeholder_token="*", - center_crop=False, - ): - self.data_root = data_root - self.tokenizer = tokenizer - self.learnable_property = learnable_property - self.size = size - self.placeholder_token = placeholder_token - self.center_crop = center_crop - self.flip_p = flip_p - - self.image_paths = [ - os.path.join(self.data_root, file_path) - for file_path in os.listdir(self.data_root) - ] - - self.num_images = len(self.image_paths) - self._length = self.num_images - - if set == "train": - self._length = self.num_images * repeats - - self.interpolation = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] - - self.templates = ( - imagenet_style_templates_small - if learnable_property == "style" - else imagenet_templates_small - ) - self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) - - def __len__(self): - return self._length - - def __getitem__(self, i): - example = {} - image = 
Image.open(self.image_paths[i % self.num_images]) - - if not image.mode == "RGB": - image = image.convert("RGB") - - placeholder_string = self.placeholder_token - text = random.choice(self.templates).format(placeholder_string) - - example["input_ids"] = self.tokenizer( - text, - padding="max_length", - truncation=True, - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids[0] - - # default to score-sde preprocessing - img = np.array(image).astype(np.uint8) - - if self.center_crop: - crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( - img.shape[0], - img.shape[1], - ) - img = img[ - (h - crop) // 2 : (h + crop) // 2, - (w - crop) // 2 : (w + crop) // 2, - ] - - image = Image.fromarray(img) - image = image.resize( - (self.size, self.size), resample=self.interpolation - ) - - image = self.flip_transform(image) - image = np.array(image).astype(np.uint8) - image = (image / 127.5 - 1.0).astype(np.float32) - - example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) - return example - - -########## Setting up the model ########## - -# Load the tokenizer and add the placeholder token as a additional special token. -tokenizer = CLIPTokenizer.from_pretrained( - pretrained_model_name_or_path, - subfolder="tokenizer", -) - -# Add the placeholder token in tokenizer -num_added_tokens = tokenizer.add_tokens(args.placeholder_token) -if num_added_tokens == 0: - raise ValueError( - f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" - " `placeholder_token` that is not already in the tokenizer." - ) - -# Get token ids for our placeholder and initializer token. -# This code block will complain if initializer string is not a single token -# Convert the initializer_token, placeholder_token to ids -token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) -# Check if initializer_token is a single token or a sequence of tokens -if len(token_ids) > 1: - raise ValueError("The initializer token must be a single token.") - -initializer_token_id = token_ids[0] -placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) - -# Load the Stable Diffusion model -# Load models and create wrapper for stable diffusion -# pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path) -# del pipeline -text_encoder = CLIPTextModel.from_pretrained( - pretrained_model_name_or_path, subfolder="text_encoder" -) -vae = AutoencoderKL.from_pretrained( - pretrained_model_name_or_path, subfolder="vae" -) -unet = UNet2DConditionModel.from_pretrained( - pretrained_model_name_or_path, subfolder="unet" -) - -# We have added the placeholder_token in the tokenizer so we resize the token embeddings here -# this will a new embedding vector in the token embeddings for our placeholder_token -text_encoder.resize_token_embeddings(len(tokenizer)) - -# Initialise the newly added placeholder token with the embeddings of the initializer token -token_embeds = text_encoder.get_input_embeddings().weight.data -token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] - -# In Textual-Inversion we only train the newly added embedding vector -# so lets freeze rest of the model parameters here - - -def freeze_params(params): - for param in params: - param.requires_grad = False - - -# Freeze vae and unet -freeze_params(vae.parameters()) -freeze_params(unet.parameters()) -# Freeze all parameters except for the token embeddings in text encoder -params_to_freeze = itertools.chain( - 
text_encoder.text_model.encoder.parameters(), - text_encoder.text_model.final_layer_norm.parameters(), - text_encoder.text_model.embeddings.position_embedding.parameters(), -) -freeze_params(params_to_freeze) - - -# Move vae and unet to device -# For the dynamo path default compilation device is `cpu`, since torch-mlir -# supports only that. Therefore, convert to device only for PyTorch path. -if not args.use_torchdynamo: - vae.to(args.device) - unet.to(args.device) - -# Keep vae in eval mode as we don't train it -vae.eval() -# Keep unet in train mode to enable gradient checkpointing -unet.train() - - -class VaeModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.vae = vae - - def forward(self, input): - x = self.vae.encode(input, return_dict=False)[0] - return x - - -class UnetModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.unet = unet - - def forward(self, x, y, z): - return self.unet.forward(x, y, z, return_dict=False)[0] - - -shark_vae = VaeModel() -shark_unet = UnetModel() - -####### Creating our training data ######## - -# Let's create the Dataset and Dataloader -train_dataset = TextualInversionDataset( - data_root=save_path, - tokenizer=tokenizer, - size=vae.sample_size, - placeholder_token=args.placeholder_token, - repeats=100, - learnable_property=args.what_to_teach, # Option selected above between object and style - center_crop=False, - set="train", -) - - -def create_dataloader(train_batch_size=1): - return torch.utils.data.DataLoader( - train_dataset, batch_size=train_batch_size, shuffle=True - ) - - -# Create noise_scheduler for training -noise_scheduler = DDPMScheduler.from_config( - pretrained_model_name_or_path, subfolder="scheduler" -) - -######## Training ########### - -# Define hyperparameters for our training. If you are not happy with your results, -# you can tune the `learning_rate` and the `max_train_steps` - -# Setting up all training args -hyperparameters = { - "learning_rate": 5e-04, - "scale_lr": True, - "max_train_steps": args.training_steps, - "save_steps": args.save_steps, - "train_batch_size": args.train_batch_size, - "gradient_accumulation_steps": 1, - "gradient_checkpointing": True, - "mixed_precision": "fp16", - "seed": 42, - "output_dir": "sd-concept-output", -} -# creating output directory -cwd = os.getcwd() -out_dir = os.path.join(cwd, hyperparameters["output_dir"]) -while not os.path.exists(str(out_dir)): - try: - os.mkdir(out_dir) - except OSError as error: - print("Output directory not created") - -###### Torch-MLIR Compilation ###### - - -def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]: - removed_indexes = [] - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, (list, tuple)): - node_arg = list(node_arg) - node_args_len = len(node_arg) - for i in range(node_args_len): - curr_index = node_args_len - (i + 1) - if node_arg[curr_index] is None: - removed_indexes.append(curr_index) - node_arg.pop(curr_index) - node.args = (tuple(node_arg),) - break - - if len(removed_indexes) > 0: - fx_g.graph.lint() - fx_g.graph.eliminate_dead_code() - fx_g.recompile() - removed_indexes.sort() - return removed_indexes - - -def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool: - """ - Replace tuple with tuple element in functions that return one-element tuples. - Returns true if an unwrapping took place, and false otherwise. 
- """ - unwrapped_tuple = False - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, tuple): - if len(node_arg) == 1: - node.args = (node_arg[0],) - unwrapped_tuple = True - break - - if unwrapped_tuple: - fx_g.graph.lint() - fx_g.recompile() - return unwrapped_tuple - - -def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool: - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, tuple): - return len(node_arg) == 0 - return False - - -def transform_fx(fx_g): - for node in fx_g.graph.nodes: - if node.op == "call_function": - if node.target in [ - torch.ops.aten.empty, - ]: - # aten.empty should be filled with zeros. - if node.target in [torch.ops.aten.empty]: - with fx_g.graph.inserting_after(node): - new_node = fx_g.graph.call_function( - torch.ops.aten.zero_, - args=(node,), - ) - node.append(new_node) - node.replace_all_uses_with(new_node) - new_node.args = (node,) - - fx_g.graph.lint() - - -@make_simple_dynamo_backend -def refbackend_torchdynamo_backend( - fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor] -): - # handling usage of empty tensor without initializing - transform_fx(fx_graph) - fx_graph.recompile() - if _returns_nothing(fx_graph): - return fx_graph - removed_none_indexes = _remove_nones(fx_graph) - was_unwrapped = _unwrap_single_tuple_return(fx_graph) - - mlir_module = torch_mlir.compile( - fx_graph, example_inputs, output_type="linalg-on-tensors" - ) - - bytecode_stream = BytesIO() - mlir_module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - - shark_module = SharkInference( - mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor" - ) - shark_module.compile() - - def compiled_callable(*inputs): - inputs = [x.numpy() for x in inputs] - result = shark_module("forward", inputs) - if was_unwrapped: - result = [ - result, - ] - if not isinstance(result, list): - result = torch.from_numpy(result) - else: - result = tuple(torch.from_numpy(x) for x in result) - result = list(result) - for removed_index in removed_none_indexes: - result.insert(removed_index, None) - result = tuple(result) - return result - - return compiled_callable - - -def predictions(torch_func, jit_func, batchA, batchB): - res = jit_func(batchA.numpy(), batchB.numpy()) - if res is not None: - prediction = res - else: - prediction = None - return prediction - - -logger = logging.getLogger(__name__) - - -# def save_progress(text_encoder, placeholder_token_id, accelerator, save_path): -def save_progress(text_encoder, placeholder_token_id, save_path): - logger.info("Saving embeddings") - learned_embeds = ( - # accelerator.unwrap_model(text_encoder) - text_encoder.get_input_embeddings().weight[placeholder_token_id] - ) - learned_embeds_dict = { - args.placeholder_token: learned_embeds.detach().cpu() - } - torch.save(learned_embeds_dict, save_path) - - -train_batch_size = hyperparameters["train_batch_size"] -gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"] -learning_rate = hyperparameters["learning_rate"] -if hyperparameters["scale_lr"]: - learning_rate = ( - learning_rate - * gradient_accumulation_steps - * train_batch_size - # * accelerator.num_processes - ) - -# Initialize the optimizer -optimizer = torch.optim.AdamW( - 
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings - lr=learning_rate, -) - - -# Training function -def train_func(batch_pixel_values, batch_input_ids): - # Convert images to latent space - latents = shark_vae(batch_pixel_values).sample().detach() - latents = latents * 0.18215 - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - noise_scheduler.num_train_timesteps, - (bsz,), - device=latents.device, - ).long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch_input_ids)[0] - - # Predict the noise residual - noise_pred = shark_unet( - noisy_latents, - timesteps, - encoder_hidden_states, - ) - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError( - f"Unknown prediction type {noise_scheduler.config.prediction_type}" - ) - - loss = ( - F.mse_loss(noise_pred, target, reduction="none").mean([1, 2, 3]).mean() - ) - loss.backward() - - # Zero out the gradients for all token embeddings except the newly added - # embeddings for the concept, as we only want to optimize the concept embeddings - grads = text_encoder.get_input_embeddings().weight.grad - # Get the index for tokens that we want to zero the grads for - index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id - grads.data[index_grads_to_zero, :] = grads.data[ - index_grads_to_zero, : - ].fill_(0) - - optimizer.step() - optimizer.zero_grad() - - return loss - - -def training_function(): - max_train_steps = hyperparameters["max_train_steps"] - output_dir = hyperparameters["output_dir"] - gradient_checkpointing = hyperparameters["gradient_checkpointing"] - - train_dataloader = create_dataloader(train_batch_size) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / gradient_accumulation_steps - ) - num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) - - # Train! - total_batch_size = ( - train_batch_size - * gradient_accumulation_steps - # train_batch_size * accelerator.num_processes * gradient_accumulation_steps - ) - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Instantaneous batch size per device = {train_batch_size}") - logger.info( - f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" - ) - logger.info( - f" Gradient Accumulation steps = {gradient_accumulation_steps}" - ) - logger.info(f" Total optimization steps = {max_train_steps}") - # Only show the progress bar once on each machine. 
- progress_bar = tqdm( - # range(max_train_steps), disable=not accelerator.is_local_main_process - range(max_train_steps) - ) - progress_bar.set_description("Steps") - global_step = 0 - - params_ = [i for i in text_encoder.get_input_embeddings().parameters()] - if args.use_torchdynamo: - print("******** TRAINING STARTED - TORCHYDNAMO PATH ********") - else: - print("******** TRAINING STARTED - PYTORCH PATH ********") - print("Initial weights:") - print(params_, params_[0].shape) - - for epoch in range(num_train_epochs): - text_encoder.train() - for step, batch in enumerate(train_dataloader): - if args.use_torchdynamo: - dynamo_callable = dynamo.optimize( - refbackend_torchdynamo_backend - )(train_func) - lam_func = lambda x, y: dynamo_callable( - torch.from_numpy(x), torch.from_numpy(y) - ) - loss = predictions( - train_func, - lam_func, - batch["pixel_values"], - batch["input_ids"], - # params[0].detach(), - ) - else: - loss = train_func(batch["pixel_values"], batch["input_ids"]) - print(loss) - - # Checks if the accelerator has performed an optimization step behind the scenes - progress_bar.update(1) - global_step += 1 - if global_step % hyperparameters["save_steps"] == 0: - save_path = os.path.join( - output_dir, - f"learned_embeds-step-{global_step}.bin", - ) - save_progress( - text_encoder, - placeholder_token_id, - save_path, - ) - - logs = {"loss": loss.detach().item()} - progress_bar.set_postfix(**logs) - - if global_step >= max_train_steps: - break - - # Create the pipeline using using the trained modules and save it. - params__ = [i for i in text_encoder.get_input_embeddings().parameters()] - print("******** TRAINING PROCESS FINISHED ********") - print("Updated weights:") - print(params__, params__[0].shape) - pipeline = StableDiffusionPipeline.from_pretrained( - pretrained_model_name_or_path, - # text_encoder=accelerator.unwrap_model(text_encoder), - text_encoder=text_encoder, - tokenizer=tokenizer, - vae=vae, - unet=unet, - ) - pipeline.save_pretrained(output_dir) - # Also save the newly trained embeddings - save_path = os.path.join(output_dir, f"learned_embeds.bin") - save_progress(text_encoder, placeholder_token_id, save_path) - - -training_function() - -for param in itertools.chain(unet.parameters(), text_encoder.parameters()): - if param.grad is not None: - del param.grad # free some memory - torch.cuda.empty_cache() - -# Set up the pipeline -from diffusers import DPMSolverMultistepScheduler - -pipe = StableDiffusionPipeline.from_pretrained( - hyperparameters["output_dir"], - scheduler=DPMSolverMultistepScheduler.from_pretrained( - hyperparameters["output_dir"], subfolder="scheduler" - ), -) -if not args.use_torchdynamo: - pipe.to(args.device) - -# Run the Stable Diffusion pipeline -# Don't forget to use the placeholder token in your prompt - -all_images = [] -for _ in range(args.num_inference_samples): - images = pipe( - [args.prompt], - num_inference_steps=args.inference_steps, - guidance_scale=7.5, - ).images - all_images.extend(images) - -output_path = os.path.abspath(os.path.join(os.getcwd(), args.output_dir)) -if not os.path.isdir(args.output_dir): - os.mkdir(args.output_dir) - -[ - image.save(f"{args.output_dir}/{i}.jpeg") - for i, image in enumerate(all_images) -] diff --git a/shark/iree_eager_backend.py b/shark/iree_eager_backend.py deleted file mode 100644 index 6be9cbf280..0000000000 --- a/shark/iree_eager_backend.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Any - -import iree -import iree.runtime as ireert -import numpy as np -import torch -from iree.runtime import DeviceArray -from torch_mlir._mlir_libs._mlir.ir import Module -from torch_mlir.compiler_utils import ( - run_pipeline_with_repro_report, -) -from torch_mlir.eager_mode.torch_mlir_eager_backend import ( - TorchMLIREagerBackend, - TensorMetaData, -) -from torch_mlir_e2e_test.eager_backends.refbackend import ( - NUMPY_TO_TORCH_DTYPE_DICT, -) - -from shark.iree_utils.compile_utils import ( - get_iree_compiled_module, - IREE_DEVICE_MAP, -) - - -class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend): - """Main entry-point for the iree backend for torch-mlir eager mode. - - EagerModeIREELinalgOnTensorsBackend uses iree.DeviceArray representations of tensors and - thus all of the wrapping and unwrapping and munging here is done to between torch.Tensor and iree.DeviceArray, - with np.ndarray as an intermediary. - """ - - def __init__(self, device: str): - self.torch_device_str = device - self.config = ireert.Config(IREE_DEVICE_MAP[device]) - self.raw_device_str = device - - def get_torch_metadata( - self, tensor: DeviceArray, kwargs: Dict[str, Any] - ) -> TensorMetaData: - return TensorMetaData( - size=tensor.shape, - dtype=NUMPY_TO_TORCH_DTYPE_DICT[tensor.dtype.type], - device=torch.device(self.torch_device_str), - requires_grad=tensor.dtype.type - in {np.float, np.float32, np.float64} - and kwargs.get("requires_grad", False), - ) - - def compile(self, imported_module: Module): - run_pipeline_with_repro_report( - imported_module, - "torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline", - "EagerMode", - ) - callable, _ = get_iree_compiled_module( - imported_module, self.raw_device_str - ) - return callable - - def copy_into(self, dst, src): - """Copy output back to appropriate arg that it should alias.""" - np.copyto(dst, src) - - def transfer_from_device_to_torch(self, e): - return torch.from_numpy(e.to_host()) - - def transfer_from_torch_to_device( - self, tensor: torch.Tensor - ) -> DeviceArray: - return iree.runtime.asdevicearray(self.config.device, tensor.numpy()) diff --git a/shark/iree_utils/__init__.py b/shark/iree_utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/shark/iree_utils/_common.py b/shark/iree_utils/_common.py deleted file mode 100644 index 1d022f67e4..0000000000 --- a/shark/iree_utils/_common.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2023 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -## Common utilities to be shared by iree utilities. -import functools -import os -import sys -import subprocess - - -def run_cmd(cmd, debug=False, raise_err=False): - """ - Inputs: - cmd : cli command string. - debug : if True, prints debug info - raise_err : if True, raise exception to caller - """ - if debug: - print("IREE run command: \n\n") - print(cmd) - print("\n\n") - try: - result = subprocess.run( - cmd, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=True, - ) - stdout = result.stdout.decode() - stderr = result.stderr.decode() - return stdout, stderr - except subprocess.CalledProcessError as e: - if raise_err: - raise Exception from e - else: - print(e.output) - sys.exit(f"Exiting program due to error running {cmd}") - - -def iree_device_map(device): - uri_parts = device.split("://", 2) - iree_driver = ( - _IREE_DEVICE_MAP[uri_parts[0]] - if uri_parts[0] in _IREE_DEVICE_MAP - else uri_parts[0] - ) - if len(uri_parts) == 1: - return iree_driver - elif "rocm" in uri_parts: - return "rocm" - else: - return f"{iree_driver}://{uri_parts[1]}" - - -def get_supported_device_list(): - return list(_IREE_DEVICE_MAP.keys()) - - -_IREE_DEVICE_MAP = { - "cpu": "local-task", - "cpu-task": "local-task", - "cpu-sync": "local-sync", - "cuda": "cuda", - "vulkan": "vulkan", - "metal": "metal", - "rocm": "rocm", - "hip": "hip", - "intel-gpu": "level_zero", -} - - -def iree_target_map(device): - if "://" in device: - device = device.split("://")[0] - return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device - - -_IREE_TARGET_MAP = { - "cpu": "llvm-cpu", - "cpu-task": "llvm-cpu", - "cpu-sync": "llvm-cpu", - "cuda": "cuda", - "vulkan": "vulkan-spirv", - "metal": "metal", - "rocm": "rocm", - "hip": "rocm", - "intel-gpu": "opencl-spirv", -} - - -# Finds whether the required drivers are installed for the given device. -@functools.cache -def check_device_drivers(device): - """ - Checks necessary drivers present for gpu and vulkan devices - False => drivers present! - """ - if "://" in device: - device = device.split("://")[0] - - from iree.runtime import get_driver - - device_mapped = iree_device_map(device) - - try: - _ = get_driver(device_mapped) - except ValueError as ve: - print( - f"[ERR] device `{device}` not registered with IREE. " - "Ensure IREE is configured for use with this device.\n" - f"Full Error: \n {repr(ve)}" - ) - return True - except RuntimeError as re: - print(f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}") - return True - - # Unknown device. We assume drivers are installed. - return False - - -# Installation info for the missing device drivers. 
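The two tables above drive all device-name handling in this module: a user-facing name such as `cuda`, `cpu`, or `vulkan://1` is resolved to an IREE runtime driver (`cuda`, `local-task`, `vulkan`, ...) and an `iree-compile` target backend (`cuda`, `llvm-cpu`, `vulkan-spirv`, ...). A small standalone illustration of that mapping follows; the dict values are copied from `_IREE_DEVICE_MAP` / `_IREE_TARGET_MAP`, while the helper name and URI handling here are simplified stand-ins, not the deleted functions themselves:

```python
# Standalone illustration of the device-name mapping above; values mirror
# a subset of _IREE_DEVICE_MAP and _IREE_TARGET_MAP.
DEVICE_TO_DRIVER = {"cpu": "local-task", "cuda": "cuda", "vulkan": "vulkan", "rocm": "rocm"}
DEVICE_TO_TARGET = {"cpu": "llvm-cpu", "cuda": "cuda", "vulkan": "vulkan-spirv", "rocm": "rocm"}

def describe(device_uri: str) -> str:
    # "vulkan://1" -> base name "vulkan" plus an optional device index suffix
    base = device_uri.split("://")[0]
    driver = DEVICE_TO_DRIVER.get(base, base)
    target = DEVICE_TO_TARGET.get(base, base)
    return f"{device_uri}: runtime driver={driver}, compile target={target}"

print(describe("cpu"))         # cpu: runtime driver=local-task, compile target=llvm-cpu
print(describe("vulkan://1"))  # vulkan://1: runtime driver=vulkan, compile target=vulkan-spirv
```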
-def device_driver_info(device): - device_driver_err_map = { - "cuda": { - "debug": "Try `nvidia-smi` on system to check.", - "solution": " from https://www.nvidia.in/Download/index.aspx?lang=en-in for your system.", - }, - "vulkan": { - "debug": "Try `vulkaninfo` on system to check.", - "solution": " from https://vulkan.lunarg.com/sdk/home for your distribution.", - }, - "metal": { - "debug": "Check if Bare metal is supported and enabled on your system.", - "solution": ".", - }, - "rocm": { - "debug": f"Try `{'hip' if sys.platform == 'win32' else 'rocm'}info` on system to check.", - "solution": " from https://rocm.docs.amd.com/en/latest/rocm.html for your system.", - }, - } - - if device in device_driver_err_map: - err_msg = ( - f"Required drivers for {device} not found. {device_driver_err_map[device]['debug']} " - f"Please install the required drivers{device_driver_err_map[device]['solution']} " - f"For further assistance please reach out to the community on discord [https://discord.com/invite/RUqY2h2s9u]" - f" and/or file a bug at https://github.com/nod-ai/SHARK/issues" - ) - return err_msg - else: - return f"{device} is not supported." diff --git a/shark/iree_utils/benchmark_utils.py b/shark/iree_utils/benchmark_utils.py deleted file mode 100644 index f32380e84b..0000000000 --- a/shark/iree_utils/benchmark_utils.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from shark.iree_utils._common import run_cmd, iree_device_map -from shark.iree_utils.cpu_utils import get_cpu_count -import numpy as np -import os -import re -import platform - -UNIT_TO_SECOND_MAP = {"us": 1e-6, "ms": 0.001, "s": 1} - - -def tensor_to_type_str(input_tensors: tuple, mlir_dialect: str): - """ - Input: A tuple of input tensors i.e tuple(torch.tensor) - Output: list of string that represent mlir types (i.e 1x24xf64) - # TODO: Support more than floats, and ints - """ - list_of_type = [] - for input_tensor in input_tensors: - type_string = "x".join([str(dim) for dim in input_tensor.shape]) - if mlir_dialect in ["linalg", "tosa"]: - dtype_string = str(input_tensor.dtype).replace("torch.", "") - elif mlir_dialect in ["mhlo", "tflite"]: - dtype = input_tensor.dtype - try: - dtype_string = re.findall("'[^\"]*'", str(dtype))[0].replace( - "'", "" - ) - except IndexError: - dtype_string = str(dtype) - regex_split = re.compile("([a-zA-Z]+)([0-9]+)") - match = regex_split.match(dtype_string) - mlir_type_string = str(match.group(1)[0]) + str(match.group(2)) - type_string += f"x{mlir_type_string}" - list_of_type.append(type_string) - return list_of_type - - -def build_benchmark_args( - input_file: str, - device: str, - input_tensors: tuple, - mlir_dialect: str, - training=False, -): - """ - Inputs: input_file leading to vmfb, input_tensor to function, target device, - and whether it is training or not. - Outputs: string that execute benchmark-module on target model. 
- """ - path = os.path.join(os.environ["VIRTUAL_ENV"], "bin") - if platform.system() == "Windows": - benchmarker_path = os.path.join(path, "iree-benchmark-module.exe") - time_extractor = None - else: - benchmarker_path = os.path.join(path, "iree-benchmark-module") - time_extractor = "| awk 'END{{print $2 $3}}'" - benchmark_cl = [benchmarker_path, f"--module={input_file}"] - # TODO: The function named can be passed as one of the args. - fn_name = "forward" - if training == True: - # TODO: Replace name of train with actual train fn name. - fn_name = "train" - benchmark_cl.append(f"--function={fn_name}") - benchmark_cl.append(f"--device={iree_device_map(device)}") - mlir_input_types = tensor_to_type_str(input_tensors, mlir_dialect) - for mlir_input in mlir_input_types: - benchmark_cl.append(f"--input={mlir_input}") - if device == "cpu": - num_cpus = get_cpu_count() - if num_cpus is not None: - benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}") - # if time_extractor: - # benchmark_cl.append(time_extractor) - benchmark_cl.append(f"--print_statistics=true") - return benchmark_cl - - -def build_benchmark_args_non_tensor_input( - input_file: str, - device: str, - inputs: tuple, - mlir_dialect: str, - function_name: str, -): - """ - Inputs: input_file leading to vmfb, input_tensor to function, target device, - and whether it is training or not. - Outputs: string that execute benchmark-module on target model. - """ - path = os.path.join(os.environ["VIRTUAL_ENV"], "bin") - if platform.system() == "Windows": - benchmarker_path = os.path.join(path, "iree-benchmark-module.exe") - time_extractor = None - else: - benchmarker_path = os.path.join(path, "iree-benchmark-module") - time_extractor = "| awk 'END{{print $2 $3}}'" - benchmark_cl = [benchmarker_path, f"--module={input_file}"] - # TODO: The function named can be passed as one of the args. - if function_name: - benchmark_cl.append(f"--function={function_name}") - benchmark_cl.append(f"--device={iree_device_map(device)}") - for input in inputs: - benchmark_cl.append(f"--input={input}") - if platform.system() != "Windows": - time_extractor = "| awk 'END{{print $2 $3}}'" - benchmark_cl.append(time_extractor) - return benchmark_cl - - -def run_benchmark_module(benchmark_cl): - """ - Run benchmark command, extract result and return iteration/seconds, host - peak memory, and device peak memory. - - # TODO: Add an example of the benchmark command. - Input: benchmark command. - """ - benchmark_path = benchmark_cl[0] - assert os.path.exists( - benchmark_path - ), "Cannot find iree_benchmark_module, Please contact SHARK maintainer on discord." - bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl)) - try: - regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)") - match = regex_split.search(bench_stdout) - time_ms = float(match.group(1)) - unit = match.group(3) - except AttributeError: - regex_split = re.compile("(\d+[.]*\d*)([a-zA-Z]+)") - match = regex_split.search(bench_stdout) - time_ms = float(match.group(1)) - unit = match.group(2) - iter_per_second = 1.0 / (time_ms * 0.001) - - # Extract peak memory. 
- host_regex = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak") - host_peak_b = int(host_regex.search(bench_stderr).group(1)) - device_regex = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak") - device_peak_b = int(device_regex.search(bench_stderr).group(1)) - return iter_per_second, host_peak_b, device_peak_b diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py deleted file mode 100644 index 208a823fa0..0000000000 --- a/shark/iree_utils/compile_utils.py +++ /dev/null @@ -1,704 +0,0 @@ -# Copyright 2023 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import functools -import numpy as np -import os -import re -import tempfile -from pathlib import Path - -#import iree.runtime as ireert -import iree.compiler as ireec -from shark.parser import shark_args - -from .trace import DetailLogger -from ._common import iree_device_map, iree_target_map -from .cpu_utils import get_iree_cpu_rt_args -from .benchmark_utils import * - - -# Get the iree-compile arguments given device. -def get_iree_device_args(device, extra_args=[]): - print("Configuring for device:" + device) - device, device_num = clean_device_info(device) - - if "cpu" in device: - from shark.iree_utils.cpu_utils import get_iree_cpu_args - - u_kernel_flag = ["--iree-llvmcpu-enable-ukernels"] - stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"] - - return ( - get_iree_cpu_args() - + u_kernel_flag - + stack_size_flag - ) - if device == "cuda": - from shark.iree_utils.gpu_utils import get_iree_gpu_args - - return get_iree_gpu_args() - if device == "vulkan": - from shark.iree_utils.vulkan_utils import get_iree_vulkan_args - - return get_iree_vulkan_args( - device_num=device_num, extra_args=extra_args - ) - if device == "metal": - from shark.iree_utils.metal_utils import get_iree_metal_args - - return get_iree_metal_args(extra_args=extra_args) - if device == "rocm": - from shark.iree_utils.gpu_utils import get_iree_rocm_args - - return get_iree_rocm_args(device_num=device_num, extra_args=extra_args) - if device == "hip": - from shark.iree_utils.gpu_utils import get_iree_rocm_args - return get_iree_rocm_args(device_num=device_num, extra_args=extra_args, hip_driver=True) - return [] - -def get_iree_target_triple(device): - args = get_iree_device_args(device) - for flag in args: - if "triple" in flag: - triple = flag.split("=")[-1] - return triple - return "" - - -def clean_device_info(raw_device): - # return appropriate device and device_id for consumption by Studio pipeline - # Multiple devices only supported for vulkan and rocm (as of now). 
- # default device must be selected for all others - - device_id = None - device = ( - raw_device - if "=>" not in raw_device - else raw_device.split("=>")[1].strip() - ) - if "://" in device: - device, device_id = device.split("://") - if len(device_id) <= 2: - device_id = int(device_id) - - if device not in ["hip", "rocm", "vulkan"]: - device_id = None - if device in ["hip", "rocm", "vulkan"] and device_id == None: - device_id = 0 - return device, device_id - - -# Get the iree-compiler arguments given frontend. -def get_iree_frontend_args(frontend): - if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]: - return ["--iree-llvmcpu-target-cpu-features=host"] - elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]: - return [ - "--iree-llvmcpu-target-cpu-features=host", - "--iree-input-demote-i64-to-i32", - ] - else: - # Frontend not found. - return [] - - -# Common args to be used given any frontend or device. -def get_iree_common_args(debug=False): - common_args = [ - "--iree-util-zero-fill-elided-attrs", - "--mlir-elide-elementsattrs-if-larger=10", - ] - if debug == True: - common_args.extend( - [ - "--iree-opt-strip-assertions=false", - "--verify=true", - ] - ) - else: - common_args.extend( - [ - "--iree-opt-strip-assertions=true", - "--verify=false", - ] - ) - return common_args - - -# Args that are suitable only for certain models or groups of models. -# shark_args are passed down from pytests to control which models compile with these flags, -# but they can also be set in shark/parser.py -def get_model_specific_args(): - ms_args = [] - if shark_args.enable_conv_transform == True: - ms_args += [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc))" - ] - if shark_args.enable_img2col_transform == True: - ms_args += [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col))" - ] - if shark_args.use_winograd == True: - ms_args += [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-linalg-ext-convert-conv2d-to-winograd))" - ] - return ms_args - - -def create_dispatch_dirs(bench_dir, device): - protected_files = ["ordered-dispatches.txt"] - bench_dir_path = bench_dir.split("/") - bench_dir_path[-1] = "temp_" + bench_dir_path[-1] - tmp_bench_dir = "/".join(bench_dir_path) - for f_ in os.listdir(bench_dir): - if os.path.isfile(f"{bench_dir}/{f_}") and f_ not in protected_files: - dir_name = re.sub("\.\S*$", "", f_) - if os.path.exists(f"{bench_dir}/{dir_name}"): - os.system(f"rm -rf {bench_dir}/{dir_name}") - os.system(f"mkdir {bench_dir}/{dir_name}") - os.system(f"mv {bench_dir}/{f_} {bench_dir}/{dir_name}/{f_}") - for f_ in os.listdir(tmp_bench_dir): - if os.path.isfile(f"{tmp_bench_dir}/{f_}"): - dir_name = "" - for d_ in os.listdir(bench_dir): - if re.search(f"{d_}(?=\D)", f_): - dir_name = d_ - if dir_name != "": - os.system( - f"mv {tmp_bench_dir}/{f_} {bench_dir}/{dir_name}/{dir_name}_benchmark.mlir" - ) - - -def dump_isas(bench_dir): - for d_ in os.listdir(bench_dir): - if os.path.isdir(f"{bench_dir}/{d_}"): - for f_ in os.listdir(f"{bench_dir}/{d_}"): - if f_.endswith(".spv"): - os.system( - f"amdllpc -gfxip 11.0 {bench_dir}/{d_}/{f_} -v > \ - {bench_dir}/{d_}/isa.txt" - ) - - -def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks): - benchmark_runtimes = {} - dispatch_list = [] - all_dispatches = False - - if dispatch_benchmarks.lower().strip() == "all": - all_dispatches = True - else: - try: - dispatch_list = [ - int(dispatch_index) - for 
dispatch_index in dispatch_benchmarks.split(" ") - ] - except: - print("ERROR: Invalid dispatch benchmarks") - return None - for d_ in os.listdir(bench_dir): - if os.path.isdir(f"{bench_dir}/{d_}"): - in_dispatches = False - for dispatch in dispatch_list: - if str(dispatch) in d_: - in_dispatches = True - if all_dispatches or in_dispatches: - for f_ in os.listdir(f"{bench_dir}/{d_}"): - if "benchmark.mlir" in f_: - dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r") - module = dispatch_file.read() - dispatch_file.close() - - flatbuffer_blob = ireec.compile_str( - module, target_backends=[iree_target_map(device)] - ) - - vmfb_file = open( - f"{bench_dir}/{d_}/{d_}_benchmark.vmfb", "wb" - ) - vmfb_file.write(flatbuffer_blob) - vmfb_file.close() - - config = get_iree_runtime_config(device) - vm_module = ireert.VmModule.from_buffer( - config.vm_instance, - flatbuffer_blob, - warn_if_copy=False, - ) - - benchmark_cl = build_benchmark_args_non_tensor_input( - input_file=f"{bench_dir}/{d_}/{d_}_benchmark.vmfb", - device=device, - inputs=(0,), - mlir_dialect="linalg", - function_name="", - ) - - benchmark_bash = open( - f"{bench_dir}/{d_}/{d_}_benchmark.sh", "w+" - ) - benchmark_bash.write("#!/bin/bash\n") - benchmark_bash.write(" ".join(benchmark_cl)) - benchmark_bash.close() - - iter_per_second, _, _ = run_benchmark_module( - benchmark_cl - ) - - benchmark_file = open( - f"{bench_dir}/{d_}/{d_}_data.txt", "w+" - ) - benchmark_file.write(f"DISPATCH: {d_}\n") - benchmark_file.write(str(iter_per_second) + "\n") - benchmark_file.write( - "SHARK BENCHMARK RESULT: " - + str(1 / (iter_per_second * 0.001)) - + "\n" - ) - benchmark_file.close() - - benchmark_runtimes[d_] = 1 / (iter_per_second * 0.001) - - elif ".mlir" in f_ and "benchmark" not in f_: - dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r") - module = dispatch_file.read() - dispatch_file.close() - - module = re.sub( - "hal.executable private", - "hal.executable public", - module, - ) - - flatbuffer_blob = ireec.compile_str( - module, - target_backends=[iree_target_map(device)], - extra_args=["--compile-mode=hal-executable"], - ) - - spirv_file = open( - f"{bench_dir}/{d_}/{d_}_spirv.vmfb", "wb" - ) - spirv_file.write(flatbuffer_blob) - spirv_file.close() - - ordered_dispatches = [ - (k, v) - for k, v in sorted( - benchmark_runtimes.items(), key=lambda item: item[1] - ) - ][::-1] - f_ = open(f"{bench_dir}/ordered-dispatches.txt", "w+") - for dispatch in ordered_dispatches: - f_.write(f"{dispatch[0]}: {dispatch[1]}ms\n") - f_.close() - - -def compile_module_to_flatbuffer( - module, - device, - frontend, - model_config_path, - extra_args, - model_name="None", - debug=False, - compile_str=False, - write_to=None, -): - # Setup Compile arguments wrt to frontends. 
- input_type = "auto" - args = get_iree_frontend_args(frontend) - args += get_iree_device_args(device, extra_args) - args += get_iree_common_args(debug=debug) - args += get_model_specific_args() - args += extra_args - args += shark_args.additional_compile_args - - if frontend in ["tensorflow", "tf"]: - input_type = "auto" - elif frontend in ["stablehlo", "tosa"]: - input_type = frontend - elif frontend in ["tflite", "tflite-tosa"]: - input_type = "tosa" - elif frontend in ["tm_tensor"]: - input_type = ireec.InputType.TM_TENSOR - elif frontend in ["torch", "pytorch"]: - input_type = "torch" - - if compile_str: - flatbuffer_blob = ireec.compile_str( - module, - target_backends=[iree_target_map(device)], - extra_args=args, - input_type=input_type, - ) - else: - assert os.path.isfile(module) - flatbuffer_blob = ireec.compile_file( - str(module), - input_type=input_type, - target_backends=[iree_target_map(device)], - extra_args=args, - ) - - if write_to is not None: - with open(write_to, "wb") as f: - f.write(flatbuffer_blob) - return None - - return flatbuffer_blob - - -def get_iree_module( - flatbuffer_blob, - device, - device_idx=None, - rt_flags: list = [], - external_weight_file=None, -): - if external_weight_file is not None: - index = ireert.ParameterIndex() - index.load(external_weight_file) - # Returns the compiled module and the configs. - for flag in rt_flags: - ireert.flags.parse_flag(flag) - if device_idx is not None: - device = iree_device_map(device) - print("registering device id: ", device_idx) - haldriver = ireert.get_driver(device) - hal_device_id = haldriver.query_available_devices()[device_idx][ - "device_id" - ] - haldevice = haldriver.create_device( - hal_device_id, - allocators=shark_args.device_allocator, - ) - config = ireert.Config(device=haldevice) - config.id = hal_device_id - else: - config = get_iree_runtime_config(device) - vm_module = ireert.VmModule.from_buffer( - config.vm_instance, flatbuffer_blob, warn_if_copy=False - ) - modules = [] - if external_weight_file is not None: - modules.append(index.create_provider(scope="model")) - ctx = ireert.SystemContext(vm_modules=modules, config=config) - ctx.add_vm_module(vm_module) - ModuleCompiled = getattr(ctx.modules, vm_module.name) - return ModuleCompiled, config - - -def load_vmfb_using_mmap( - flatbuffer_blob_or_path, - device: str, - device_idx: int = None, - rt_flags: list = [], - external_weight_file: str = None, -): - print(f"Loading module {flatbuffer_blob_or_path}...") - if "task" in device: - print( - f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}" - ) - for flag in get_iree_cpu_rt_args(): - rt_flags.append(flag) - for flag in rt_flags: - print(flag) - ireert.flags.parse_flags(flag) - - if "rocm" in device: - device = "rocm" - with DetailLogger(timeout=2.5) as dl: - # First get configs. 
- if device_idx is not None: - dl.log(f"Mapping device id: {device_idx}") - device = iree_device_map(device) - haldriver = ireert.get_driver(device) - dl.log(f"ireert.get_driver()") - - hal_device_id = haldriver.query_available_devices()[device_idx][ - "device_id" - ] - haldevice = haldriver.create_device( - hal_device_id, - allocators=shark_args.device_allocator, - ) - dl.log(f"ireert.create_device()") - config = ireert.Config(device=haldevice) - config.id = hal_device_id - dl.log(f"ireert.Config()") - else: - config = get_iree_runtime_config(device) - dl.log("get_iree_runtime_config") - if "task" in device: - print( - f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}" - ) - for flag in get_iree_cpu_rt_args(): - ireert.flags.parse_flags(flag) - - # Now load vmfb. - # Two scenarios we have here :- - # 1. We either have the vmfb already saved and therefore pass the path of it. - # (This would arise if we're invoking `load_module` from a SharkInference obj) - # OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with. - # (This would arise if we're invoking `compile` from a SharkInference obj) - temp_file_to_unlink = None - if isinstance(flatbuffer_blob_or_path, Path): - flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__() - if ( - isinstance(flatbuffer_blob_or_path, str) - and ".vmfb" in flatbuffer_blob_or_path - ): - vmfb_file_path = flatbuffer_blob_or_path - mmaped_vmfb = ireert.VmModule.mmap( - config.vm_instance, flatbuffer_blob_or_path - ) - vm_modules = [] - if external_weight_file is not None: - index = ireert.ParameterIndex() - index.load(external_weight_file) - param_module = ireert.create_io_parameters_module( - config.vm_instance, index.create_provider(scope="model") - ) - vm_modules.append(param_module) - vm_modules.append(mmaped_vmfb) - vm_modules.append( - ireert.create_hal_module(config.vm_instance, config.device) - ) - dl.log(f"mmap {flatbuffer_blob_or_path}") - if "vulkan" in device: - # Vulkan pipeline creation consumes significant amount of time. - print( - "\tCompiling Vulkan shaders. This may take a few minutes." - ) - ctx = ireert.SystemContext(config=config, vm_modules=vm_modules) - dl.log(f"ireert.SystemContext created") - for flag in shark_args.additional_runtime_args: - ireert.flags.parse_flags(flag) - dl.log(f"module initialized") - mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name) - else: - with tempfile.NamedTemporaryFile(delete=False) as tf: - tf.write(flatbuffer_blob_or_path) - tf.flush() - vmfb_file_path = tf.name - temp_file_to_unlink = vmfb_file_path - mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path) - dl.log(f"mmap temp {vmfb_file_path}") - return mmaped_vmfb, config, temp_file_to_unlink - - -def get_iree_compiled_module( - module, - device: str, - frontend: str = "torch", - model_config_path: str = None, - extra_args: list = [], - rt_flags: list = [], - device_idx: int = None, - mmap: bool = False, - debug: bool = False, - compile_str: bool = False, - external_weight_file: str = None, - write_to: bool = None, -): - """Given a module returns the compiled .vmfb and configs""" - flatbuffer_blob = compile_module_to_flatbuffer( - module=module, - device=device, - frontend=frontend, - model_config_path=model_config_path, - extra_args=extra_args, - debug=debug, - compile_str=compile_str, - write_to=write_to, - ) - temp_file_to_unlink = None - # TODO: Currently mmap=True control flow path has been switched off for mmap. 
- # Got to find a cleaner way to unlink/delete the temporary file since - # we're setting delete=False when creating NamedTemporaryFile. That's why - # I'm getting hold of the name of the temporary file in `temp_file_to_unlink`. - if mmap: - if write_to is not None: - flatbuffer_blob = write_to - vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap( - flatbuffer_blob, - device, - device_idx, - rt_flags, - external_weight_file=external_weight_file, - ) - else: - vmfb, config = get_iree_module( - flatbuffer_blob, - device, - device_idx=device_idx, - rt_flags=rt_flags, - external_weight_file=external_weight_file, - ) - ret_params = { - "vmfb": vmfb, - "config": config, - "temp_file_to_unlink": temp_file_to_unlink, - } - return ret_params - - -def load_flatbuffer( - flatbuffer_path: str, - device: str, - device_idx: int = None, - mmap: bool = False, - rt_flags: list = [], -): - temp_file_to_unlink = None - if mmap: - vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap( - flatbuffer_path, device, device_idx, rt_flags - ) - else: - with open(os.path.join(flatbuffer_path), "rb") as f: - flatbuffer_blob = f.read() - vmfb, config = get_iree_module( - flatbuffer_blob, - device, - device_idx=device_idx, - rt_flags=rt_flags, - ) - ret_params = { - "vmfb": vmfb, - "config": config, - "temp_file_to_unlink": temp_file_to_unlink, - } - return ret_params - - -def export_iree_module_to_vmfb( - module, - device: str, - directory: str, - mlir_dialect: str = "linalg", - model_config_path: str = None, - module_name: str = None, - extra_args: list = [], - debug: bool = False, - compile_str: bool = False, -): - # Compiles the module given specs and saves it as .vmfb file. - flatbuffer_blob = compile_module_to_flatbuffer( - module=module, - device=device, - frontend=mlir_dialect, - model_config_path=model_config_path, - extra_args=extra_args, - debug=debug, - compile_str=compile_str, - ) - if module_name is None: - device_name = ( - device if "://" not in device else "-".join(device.split("://")) - ) - module_name = f"{mlir_dialect}_{device_name}" - filename = os.path.join(directory, module_name + ".vmfb") - with open(filename, "wb") as f: - f.write(flatbuffer_blob) - print(f"Saved vmfb in {filename}.") - return filename - - -def export_module_to_mlir_file(module, frontend, directory: str): - # TODO: write proper documentation. 
- mlir_str = module - if frontend in ["tensorflow", "tf", "mhlo", "stablehlo", "tflite"]: - mlir_str = module.decode("utf-8") - elif frontend in ["pytorch", "torch"]: - mlir_str = module.operation.get_asm() - filename = os.path.join(directory, "model.mlir") - with open(filename, "w") as f: - f.write(mlir_str) - print(f"Saved mlir in {filename}.") - return filename - - -def get_results( - compiled_vm, - function_name, - input, - config, - frontend="torch", - send_to_host=True, - debug_timeout: float = 5.0, - device: str = None, -): - """Runs a .vmfb file given inputs and config and returns output.""" - with DetailLogger(debug_timeout) as dl: - device_inputs = [] - if device == "rocm" and hasattr(config, "id"): - haldriver = ireert.get_driver("rocm") - haldevice = haldriver.create_device( - config.id, - allocators=shark_args.device_allocator, - ) - for input_array in input: - dl.log(f"Load to device: {input_array.shape}") - device_inputs.append( - ireert.asdevicearray(config.device, input_array) - ) - dl.log(f"Invoke function: {function_name}") - result = compiled_vm[function_name](*device_inputs) - dl.log(f"Invoke complete") - result_tensors = [] - if isinstance(result, tuple): - if send_to_host: - for val in result: - dl.log(f"Result to host: {val.shape}") - result_tensors.append(np.asarray(val, val.dtype)) - else: - for val in result: - result_tensors.append(val) - return result_tensors - elif isinstance(result, dict): - data = list(result.items()) - if send_to_host: - res = np.array(data, dtype=object) - return np.copy(res) - return data - else: - if send_to_host and result is not None: - dl.log("Result to host") - return result.to_host() - return result - dl.log("Execution complete") - - -# @functools.cache -# def get_iree_runtime_config(device): -# device = iree_device_map(device) -# haldriver = ireert.get_driver(device) -# if "metal" in device and shark_args.device_allocator == "caching": -# print( -# "[WARNING] metal devices can not have a `caching` allocator." -# "\nUsing default allocator `None`" -# ) -# haldevice = haldriver.create_device_by_uri( -# device, -# # metal devices have a failure with caching allocators atm. blcking this util it gets fixed upstream. -# allocators=shark_args.device_allocator -# if "metal" not in device -# else None, -# ) -# config = ireert.Config(device=haldevice) -# return config diff --git a/shark/iree_utils/cpu_utils.py b/shark/iree_utils/cpu_utils.py deleted file mode 100644 index 8aca117093..0000000000 --- a/shark/iree_utils/cpu_utils.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# All the iree_cpu related functionalities go here. - -import functools -import subprocess -import platform -from shark.parser import shark_args - - -def get_cpu_count(): - import multiprocessing - - try: - cpu_count = multiprocessing.cpu_count() - return cpu_count - except NotImplementedError: - return None - - -# Get the default cpu args. 
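The invocation path that the removed get_results wrapped is similarly small when written against iree.runtime directly. A minimal sketch, assuming module and config came from a loader like the one sketched earlier and that the inputs are NumPy arrays; the function name is illustrative.

import numpy as np
import iree.runtime as ireert

def run_exported_function(module, config, function_name, np_inputs):
    # Stage host arrays onto the HAL device, invoke, and copy results back.
    device_inputs = [ireert.asdevicearray(config.device, arr) for arr in np_inputs]
    result = module[function_name](*device_inputs)
    if isinstance(result, tuple):
        return [np.asarray(r, r.dtype) for r in result]
    return result.to_host() if result is not None else None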
-@functools.cache -def get_iree_cpu_args(): - uname = platform.uname() - os_name, proc_name = uname.system, uname.machine - - if os_name == "Darwin": - kernel_version = uname.release - target_triple = f"{proc_name}-apple-darwin{kernel_version}" - elif os_name == "Linux": - target_triple = f"{proc_name}-linux-gnu" - elif os_name == "Windows": - target_triple = "x86_64-pc-windows-msvc" - else: - error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)" - raise Exception(error_message) - print(f"Target triple found:{target_triple}") - return [ - f"--iree-llvmcpu-target-triple={target_triple}", - ] - - -# Get iree runtime flags for cpu -@functools.cache -def get_iree_cpu_rt_args(): - default = get_cpu_count() - default = default if default <= 8 else default - 2 - cpu_count = ( - default - if shark_args.task_topology_max_group_count is None - else shark_args.task_topology_max_group_count - ) - return [f"--task_topology_max_group_count={cpu_count}"] diff --git a/shark/iree_utils/gpu_utils.py b/shark/iree_utils/gpu_utils.py deleted file mode 100644 index db6ef14e34..0000000000 --- a/shark/iree_utils/gpu_utils.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# All the iree_gpu related functionalities go here. - -import functools -import iree.runtime as ireert -import ctypes -import sys -from subprocess import CalledProcessError -from shark.parser import shark_args -from shark.iree_utils._common import run_cmd - -# TODO: refactor to rocm and cuda utils - - -# Get the default gpu args given the architecture. -@functools.cache -def get_iree_gpu_args(): - ireert.flags.FUNCTION_INPUT_VALIDATION = False - ireert.flags.parse_flags("--cuda_allow_inline_execution") - # TODO: Give the user_interface to pass the sm_arch. 
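The CPU helpers removed just above reduce to two ingredients when calling IREE directly: a host target triple for compilation and a task-topology flag for the runtime. A minimal sketch of the compile side, assuming the iree-compiler Python package is installed; the output filename and the Linux fallback triple are illustrative defaults rather than project settings.

import platform
import iree.compiler as ireec

def compile_for_host_cpu(mlir_text, vmfb_path="module.vmfb"):
    uname = platform.uname()
    # Same triple derivation as the removed get_iree_cpu_args().
    if uname.system == "Darwin":
        triple = f"{uname.machine}-apple-darwin{uname.release}"
    elif uname.system == "Windows":
        triple = "x86_64-pc-windows-msvc"
    else:
        triple = f"{uname.machine}-linux-gnu"
    flatbuffer = ireec.compile_str(
        mlir_text,
        target_backends=["llvm-cpu"],
        extra_args=[f"--iree-llvmcpu-target-triple={triple}"],
    )
    with open(vmfb_path, "wb") as f:
        f.write(flatbuffer)
    return vmfb_path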
- sm_arch = get_cuda_sm_cc() - if ( - sm_arch - in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"] - ) and (shark_args.enable_tf32 == True): - return [ - f"--iree-hal-cuda-llvm-target-arch={sm_arch}", - ] - else: - return [] - - -def check_rocm_device_arch_in_args(extra_args): - # Check if the target arch flag for rocm device present in extra_args - for flag in extra_args: - if "iree-rocm-target-chip" in flag: - flag_arch = flag.split("=")[1] - return flag_arch - return None - - -def get_rocm_device_arch(device_num=0, extra_args=[], hip_driver=False): - # ROCM Device Arch selection: - # 1 : User given device arch using `--iree-rocm-target-chip` flag - # 2 : Device arch from `iree-run-module --dump_devices=rocm` for device on index - # 3 : default arch : gfx1100 - - arch_in_flag = check_rocm_device_arch_in_args(extra_args) - if arch_in_flag is not None: - print( - f"User Specified rocm target device arch from flag : {arch_in_flag} will be used" - ) - return arch_in_flag - - arch_in_device_dump = None - - # get rocm arch from iree dump devices - def get_devices_info_from_dump(dump, driver): - from os import linesep - - if driver == "hip": - dump_clean = list( - filter( - lambda s: "AMD" in s, - dump.split(linesep), - ) - ) - else: - dump_clean = list( - filter( - lambda s: f"--device={driver}" in s or "gpu-arch-name:" in s, - dump.split(linesep), - ) - ) - arch_pairs = [ - ( - dump_clean[i].split("=")[1].strip(), - dump_clean[i + 1].split(":")[1].strip(), - ) - for i in range(0, len(dump_clean), 2) - ] - return arch_pairs - - dump_device_info = None - driver = "hip" if hip_driver else "rocm" - try: - dump_device_info = run_cmd( - "iree-run-module --dump_devices=" + driver, raise_err=True - ) - except Exception as e: - print("could not execute `iree-run-module --dump_devices=" + driver + "`") - - if dump_device_info is not None: - device_num = 0 if device_num is None else device_num - device_arch_pairs = get_devices_info_from_dump(dump_device_info[0], driver) - if len(device_arch_pairs) > device_num: # can find arch in the list - arch_in_device_dump = device_arch_pairs[device_num][1] - - if arch_in_device_dump is not None: - print(f"Found ROCm device arch : {arch_in_device_dump}") - return arch_in_device_dump - - default_rocm_arch = "gfx1100" - print( - "Did not find ROCm architecture from `--iree-rocm-target-chip` flag" - "\n or from `iree-run-module --dump_devices` command." - f"\nUsing {default_rocm_arch} as ROCm arch for compilation." - ) - return default_rocm_arch - - -# Get the default gpu args given the architecture. 
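The same device-architecture probe can be done without the surrounding plumbing. A sketch of the dump-and-parse approach the removed get_rocm_device_arch used, assuming iree-run-module is on PATH and keeping gfx1100 as the fallback; the helper name is illustrative and the dump format may vary across IREE releases.

import subprocess

def detect_rocm_arch(device_num=0, fallback="gfx1100"):
    try:
        dump = subprocess.run(
            ["iree-run-module", "--dump_devices=rocm"],
            capture_output=True, text=True, check=True,
        ).stdout
    except (OSError, subprocess.CalledProcessError):
        return fallback
    # The rocm dump includes "gpu-arch-name: gfxNNNN" lines, one per device.
    archs = [
        line.split(":", 1)[1].strip()
        for line in dump.splitlines()
        if "gpu-arch-name:" in line
    ]
    return archs[device_num] if device_num < len(archs) else fallback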
-def get_iree_rocm_args(device_num=0, extra_args=[], hip_driver=False): - ireert.flags.FUNCTION_INPUT_VALIDATION = False - rocm_flags = [] - if check_rocm_device_arch_in_args(extra_args) is None: - rocm_arch = get_rocm_device_arch(device_num, extra_args, hip_driver=hip_driver) - rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}") - - return rocm_flags - -# Some constants taken from cuda.h -CUDA_SUCCESS = 0 -CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16 -CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR = 39 -CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13 -CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36 - - -@functools.cache -def get_cuda_sm_cc(): - libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll") - for libname in libnames: - try: - cuda = ctypes.CDLL(libname) - except OSError: - continue - else: - break - else: - raise OSError("could not load any of: " + " ".join(libnames)) - - nGpus = ctypes.c_int() - name = b" " * 100 - cc_major = ctypes.c_int() - cc_minor = ctypes.c_int() - - result = ctypes.c_int() - device = ctypes.c_int() - context = ctypes.c_void_p() - error_str = ctypes.c_char_p() - - result = cuda.cuInit(0) - if result != CUDA_SUCCESS: - cuda.cuGetErrorString(result, ctypes.byref(error_str)) - print( - "cuInit failed with error code %d: %s" - % (result, error_str.value.decode()) - ) - return 1 - result = cuda.cuDeviceGetCount(ctypes.byref(nGpus)) - if result != CUDA_SUCCESS: - cuda.cuGetErrorString(result, ctypes.byref(error_str)) - print( - "cuDeviceGetCount failed with error code %d: %s" - % (result, error_str.value.decode()) - ) - return 1 - print("Found %d device(s)." % nGpus.value) - for i in range(nGpus.value): - result = cuda.cuDeviceGet(ctypes.byref(device), i) - if result != CUDA_SUCCESS: - cuda.cuGetErrorString(result, ctypes.byref(error_str)) - print( - "cuDeviceGet failed with error code %d: %s" - % (result, error_str.value.decode()) - ) - return 1 - print("Device: %d" % i) - if ( - cuda.cuDeviceGetName(ctypes.c_char_p(name), len(name), device) - == CUDA_SUCCESS - ): - print(" Name: %s" % (name.split(b"\0", 1)[0].decode())) - if ( - cuda.cuDeviceComputeCapability( - ctypes.byref(cc_major), ctypes.byref(cc_minor), device - ) - == CUDA_SUCCESS - ): - print( - " Compute Capability: %d.%d" - % (cc_major.value, cc_minor.value) - ) - sm = f"sm_{cc_major.value}{cc_minor.value}" - return sm diff --git a/shark/iree_utils/metal_utils.py b/shark/iree_utils/metal_utils.py deleted file mode 100644 index 21dfee8481..0000000000 --- a/shark/iree_utils/metal_utils.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2023 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# All the iree_vulkan related functionalities go here. 
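Stripped of its logging, the ctypes probe in the removed get_cuda_sm_cc() comes down to three CUDA driver calls. A condensed sketch with error handling reduced to exceptions; the library names and CUDA_SUCCESS value are taken from the code above, and the function name is illustrative.

import ctypes

CUDA_SUCCESS = 0

def cuda_sm_arch(device_index=0):
    # Load whichever CUDA driver library is present on this platform.
    for name in ("libcuda.so", "libcuda.dylib", "nvcuda.dll"):
        try:
            cuda = ctypes.CDLL(name)
            break
        except OSError:
            continue
    else:
        raise OSError("could not load the CUDA driver library")
    if cuda.cuInit(0) != CUDA_SUCCESS:
        raise RuntimeError("cuInit failed")
    device = ctypes.c_int()
    if cuda.cuDeviceGet(ctypes.byref(device), device_index) != CUDA_SUCCESS:
        raise RuntimeError("cuDeviceGet failed")
    major, minor = ctypes.c_int(), ctypes.c_int()
    if (
        cuda.cuDeviceComputeCapability(
            ctypes.byref(major), ctypes.byref(minor), device
        )
        != CUDA_SUCCESS
    ):
        raise RuntimeError("cuDeviceComputeCapability failed")
    return f"sm_{major.value}{minor.value}"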
- -import functools - -from shark.iree_utils._common import run_cmd -import iree.runtime as ireert -from sys import platform -from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag - - -@functools.cache -def get_metal_device_name(device_num=0): - iree_device_dump = run_cmd("iree-run-module --dump_devices") - iree_device_dump = iree_device_dump[0].split("\n\n") - metal_device_list = [ - s.split("\n#")[2] for s in iree_device_dump if "--device=metal" in s - ] - if len(metal_device_list) == 0: - raise ValueError("No device name found in device dump!") - if len(metal_device_list) > 1: - print("Following devices found:") - for i, dname in enumerate(metal_device_list): - print(f"{i}. {dname}") - print(f"Choosing device: {metal_device_list[device_num]}") - return metal_device_list[device_num] - - -def get_os_name(): - if platform.startswith("linux"): - return "linux" - elif platform == "darwin": - return "macos" - elif platform == "win32": - return "windows" - else: - print("Cannot detect OS type, defaulting to linux.") - return "linux" - - -def get_metal_target_triple(device_name): - """This method provides a target triple str for specified vulkan device. - - Args: - device_name (str): name of the hardware device to be used with vulkan - - Returns: - str or None: target triple or None if no match found for given name - """ - return "macos" - - -def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]): - for flag in extra_args: - if "-iree-metal-target-platform=" in flag: - print(f"Using target triple {flag.split('=')[1]}") - return None - - if device_name == "" or device_name == [] or device_name is None: - metal_device = get_metal_device_name(device_num=device_num) - else: - metal_device = device_name - triple = get_metal_target_triple(metal_device) - if triple is not None: - print( - f"Found metal device {metal_device}. Using metal target platform {triple}" - ) - return f"-iree-metal-target-platform={triple}" - print( - """Optimized kernel for your target device is not added yet. - Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u] - or pull up an issue.""" - ) - print(f"Target : {metal_device}") - return None - - -def get_iree_metal_args(device_num=0, extra_args=[]): - # Add any metal spefic compilation flags here - res_metal_flag = [] - if len(extra_args) > 0: - res_metal_flag.extend(extra_args) - return res_metal_flag - - -def set_iree_metal_runtime_flags(flags): - for flag in flags: - ireert.flags.parse_flags(flag) - return diff --git a/shark/iree_utils/trace.py b/shark/iree_utils/trace.py deleted file mode 100644 index ea51243b87..0000000000 --- a/shark/iree_utils/trace.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2023 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
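The set_iree_metal_runtime_flags helper above (like its Vulkan counterpart later in this patch) just loops over ireert.flags.parse_flags; callers porting off the removed utilities can parse runtime flags up front, before drivers or devices are created. The flag values below are illustrative, not recommended settings.

import iree.runtime as ireert

# Parse runtime flags before any driver/device is instantiated.
for flag in (
    "--task_topology_max_group_count=8",
    "--vulkan_validation_layers=false",
):
    ireert.flags.parse_flags(flag)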
- -from typing import List, Tuple - -import os -import threading -import time - - -def _enable_detail_trace() -> bool: - return os.getenv("SHARK_DETAIL_TRACE", "0") == "1" - - -class DetailLogger: - """Context manager which can accumulate detailed log messages. - - Detailed log is only emitted if the operation takes a long time - or errors. - """ - - def __init__(self, timeout: float): - self._timeout = timeout - self._messages: List[Tuple[float, str]] = [] - self._start_time = time.time() - self._active = not _enable_detail_trace() - self._lock = threading.RLock() - self._cond = threading.Condition(self._lock) - self._thread = None - - def __enter__(self): - self._thread = threading.Thread(target=self._run) - self._thread.start() - return self - - def __exit__(self, type, value, traceback): - with self._lock: - self._active = False - self._cond.notify() - if traceback: - self.dump_on_error(f"exception") - - def _run(self): - with self._lock: - timed_out = not self._cond.wait(self._timeout) - if timed_out: - self.dump_on_error(f"took longer than {self._timeout}s") - - def log(self, msg): - with self._lock: - timestamp = time.time() - if self._active: - self._messages.append((timestamp, msg)) - else: - print(f" +{(timestamp - self._start_time) * 1000}ms: {msg}") - - def dump_on_error(self, summary: str): - with self._lock: - if self._active: - print(f"::: Detailed report ({summary}):") - for timestamp, msg in self._messages: - print( - f" +{(timestamp - self._start_time) * 1000}ms: {msg}" - ) - self._active = False diff --git a/shark/iree_utils/vulkan_target_env_utils.py b/shark/iree_utils/vulkan_target_env_utils.py deleted file mode 100644 index 7cd1b05241..0000000000 --- a/shark/iree_utils/vulkan_target_env_utils.py +++ /dev/null @@ -1,538 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
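A short usage sketch of the DetailLogger removed above, for anyone porting call sites: with SHARK_DETAIL_TRACE unset (or "0") messages are buffered and only dumped if the block exceeds the timeout or raises, while SHARK_DETAIL_TRACE=1 prints every message immediately with a millisecond offset. The import refers to the module deleted by this patch, so the snippet only runs against a pre-removal tree.

import os
import time

from shark.iree_utils.trace import DetailLogger  # module removed by this patch

os.environ.setdefault("SHARK_DETAIL_TRACE", "0")   # buffer-and-dump mode
with DetailLogger(timeout=2.0) as dl:
    dl.log("loading module")
    time.sleep(0.1)            # fast path: buffered messages are discarded
    dl.log("invoking function")
# Had the block taken longer than 2s or raised, the buffered messages would
# have been printed with their timestamps via dump_on_error().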
- -from collections import OrderedDict -import functools - - -@functools.cache -def get_vulkan_target_env(vulkan_target_triple): - arch, product, os = vulkan_target_triple.split("=")[1].split("-") - triple = (arch, product, os) - # get version - version = get_version(triple=triple) - # TODO get revision - revision = 120 - - # extensions - extensions = get_extensions(triple) - # get vendor - vendor = get_vendor(triple) - # get device type - device_type = get_device_type(triple) - # get capabilities - capabilities = get_vulkan_target_capabilities(triple) - target_env = f"<#spirv.vce<{version}, r({revision}), {extensions}>, {vendor}:{device_type}, #spirv.resource_limits< {capabilities} >>" - return target_env - - -def get_vulkan_target_env_flag(vulkan_target_triple): - target_env = get_vulkan_target_env(vulkan_target_triple) - target_env_flag = f"--iree-vulkan-target-env={target_env}" - return target_env_flag - - -def get_version(triple): - arch, product, os = triple - if os in ["android30", "android31"]: - return "v1.1" - if product in ["android30", "android31"]: - return "v1.1" - if arch in ["unknown"]: - return "v1.1" - return "v1.3" - - -@functools.cache -def get_extensions(triple): - def make_ext_list(ext_list): - res = ", ".join(ext_list) - return f"[{res}]" - - arch, product, os = triple - if arch == "m1": - ext = [ - "SPV_KHR_16bit_storage", - "SPV_KHR_8bit_storage", - "SPV_KHR_shader_float16_int8", - "SPV_KHR_storage_buffer_storage_class", - "SPV_KHR_variable_pointers", - ] - return make_ext_list(ext_list=ext) - - if arch == "valhall": - ext = [ - "SPV_KHR_16bit_storage", - "SPV_KHR_8bit_storage", - "SPV_KHR_shader_float16_int8", - "SPV_KHR_spirv_1_4", - "SPV_KHR_storage_buffer_storage_class", - "SPV_KHR_variable_pointers", - ] - return make_ext_list(ext_list=ext) - - if arch == "adreno": - ext = [ - "SPV_KHR_16bit_storage", - "SPV_KHR_shader_float16_int8", - "SPV_KHR_spirv_1_4", - "SPV_KHR_storage_buffer_storage_class", - "SPV_KHR_variable_pointers", - ] - if os == "android31": - ext.append("SPV_KHR_8bit_storage") - return make_ext_list(ext_list=ext) - - if get_vendor(triple) == "SwiftShader": - ext = ["SPV_KHR_storage_buffer_storage_class"] - return make_ext_list(ext_list=ext) - - if arch == "unknown": - ext = [ - "SPV_KHR_storage_buffer_storage_class", - "SPV_KHR_variable_pointers", - ] - return make_ext_list(ext_list=ext) - - ext = [ - "SPV_KHR_16bit_storage", - "SPV_KHR_8bit_storage", - "SPV_KHR_shader_float16_int8", - "SPV_KHR_spirv_1_4", - "SPV_KHR_storage_buffer_storage_class", - "SPV_KHR_variable_pointers", - "VK_EXT_subgroup_size_control", - ] - - if get_vendor(triple) == "NVIDIA" or arch == "rdna3": - ext.append("SPV_KHR_cooperative_matrix") - if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]: - ext.append("SPV_KHR_shader_integer_dot_product") - return make_ext_list(ext_list=ext) - - -@functools.cache -def get_vendor(triple): - arch, product, os = triple - if arch == "unknown": - return "Unknown" - if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn4", "rgcn5"]: - return "AMD" - if arch == "valhall": - return "ARM" - if arch == "m1": - return "Apple" - if arch in ["arc", "UHD"]: - return "Intel" - if arch in ["turing", "ampere", "pascal"]: - return "NVIDIA" - if arch == "adreno": - return "Qualcomm" - if arch == "cpu": - if product == "swiftshader": - return "SwiftShader" - return "Unknown" - print(f"Vendor for target triple - {triple} not found. 
Using unknown") - return "Unknown" - - -@functools.cache -def get_device_type(triple): - arch, product, _ = triple - if arch == "unknown": - return "Unknown" - if arch == "cpu": - return "CPU" - if arch in ["turing", "ampere", "arc", "pascal"]: - return "DiscreteGPU" - if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn5"]: - if product == "ivega10": - return "IntegratedGPU" - return "DiscreteGPU" - if arch in ["m1", "valhall", "adreno"]: - return "IntegratedGPU" - print(f"Device type for target triple - {triple} not found. Using unknown") - return "Unknown" - - -# get all the capabilities for the device -# TODO: make a dataclass for capabilites and init using vulkaninfo -@functools.cache -def get_vulkan_target_capabilities(triple): - def get_subgroup_val(l): - return int(sum([subgroup_feature[sgf] for sgf in l])) - - cap = OrderedDict() - arch, product, os = triple - subgroup_feature = { - "Basic": 1, - "Vote": 2, - "Arithmetic": 4, - "Ballot": 8, - "Shuffle": 16, - "ShuffleRelative": 32, - "Clustered": 64, - "Quad": 128, - "PartitionedNV": 256, - } - cap["max_compute_shared_memory_size"] = 16384 - cap["max_compute_workgroup_invocations"] = 128 - cap["max_compute_workgroup_size"] = [128, 128, 64] - cap["subgroup_size"] = 32 - cap["subgroupFeatures"] = ["Basic"] - cap["min_subgroup_size"] = None - cap["max_subgroup_size"] = None - cap["shaderFloat16"] = False - cap["shaderFloat64"] = False - cap["shaderInt8"] = False - cap["shaderInt16"] = False - cap["shaderInt64"] = False - cap["storageBuffer16BitAccess"] = False - cap["storagePushConstant16"] = False - cap["uniformAndStorageBuffer16BitAccess"] = False - cap["storageBuffer8BitAccess"] = False - cap["storagePushConstant8"] = False - cap["uniformAndStorageBuffer8BitAccess"] = False - cap["variablePointers"] = False - cap["variablePointersStorageBuffer"] = False - cap["coopmatCases"] = None - - if arch in ["rdna1", "rdna2", "rdna3"]: - cap["max_compute_shared_memory_size"] = 65536 - cap["max_compute_workgroup_invocations"] = 1024 - cap["max_compute_workgroup_size"] = [1024, 1024, 1024] - - cap["subgroup_size"] = 64 - cap["min_subgroup_size"] = 32 - cap["max_subgroup_size"] = 64 - cap["subgroupFeatures"] = [ - "Basic", - "Vote", - "Arithmetic", - "Ballot", - "Shuffle", - "ShuffleRelative", - "Clustered", - "Quad", - ] - - cap["shaderFloat16"] = True - cap["shaderFloat64"] = True - cap["shaderInt8"] = True - cap["shaderInt16"] = True - cap["shaderInt64"] = True - cap["shaderIntegerDotProduct"] = True - cap["storageBuffer16BitAccess"] = True - cap["storagePushConstant16"] = True - cap["uniformAndStorageBuffer16BitAccess"] = True - cap["storageBuffer8BitAccess"] = True - cap["storagePushConstant8"] = True - cap["uniformAndStorageBuffer8BitAccess"] = True - cap["variablePointers"] = True - cap["variablePointersStorageBuffer"] = True - if arch == "rdna3": - # TODO: Get scope value - cap["coopmatCases"] = [ - "m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = ", - "m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = " - ] - - if product == "rx5700xt": - cap["storagePushConstant16"] = False - cap["storagePushConstant8"] = False - - elif arch in ["rgcn5", "rgcn4", "rgcn3"]: - cap["max_compute_shared_memory_size"] = 65536 - cap["max_compute_workgroup_invocations"] = 1024 - cap["max_compute_workgroup_size"] = [1024, 1024, 1024] - - cap["subgroup_size"] = 64 - cap["subgroupFeatures"] = [ - "Basic", - 
"Vote", - "Arithmetic", - "Ballot", - "Shuffle", - "ShuffleRelative", - "Clustered", - "Quad", - ] - cap["min_subgroup_size"] = 64 - cap["max_subgroup_size"] = 64 - - if arch == "rgcn5": - cap["shaderFloat16"] = True - cap["shaderFloat64"] = True - - cap["storageBuffer16BitAccess"] = True - - cap["shaderInt8"] = True - cap["shaderInt16"] = True - cap["shaderInt64"] = True - cap["shaderIntegerDotProduct"] = True - cap["storagePushConstant16"] = False - cap["uniformAndStorageBuffer16BitAccess"] = True - cap["storageBuffer8BitAccess"] = True - cap["storagePushConstant8"] = False - cap["uniformAndStorageBuffer8BitAccess"] = True - - cap["variablePointers"] = True - cap["variablePointersStorageBuffer"] = True - - elif arch == "m1": - cap["max_compute_shared_memory_size"] = 32768 - cap["max_compute_workgroup_invocations"] = 1024 - cap["max_compute_workgroup_size"] = [1024, 1024, 1024] - - cap["subgroup_size"] = 32 - cap["subgroupFeatures"] = [ - "Basic", - "Vote", - "Arithmetic", - "Ballot", - "Shuffle", - "ShuffleRelative", - "Quad", - ] - - cap["shaderFloat16"] = True - cap["shaderFloat64"] = True - cap["shaderInt8"] = True - cap["shaderInt16"] = True - cap["shaderInt64"] = True - cap["shaderIntegerDotProduct"] = False - cap["storageBuffer16BitAccess"] = True - cap["storagePushConstant16"] = True - cap["uniformAndStorageBuffer16BitAccess"] = True - cap["storageBuffer8BitAccess"] = True - cap["storagePushConstant8"] = True - cap["uniformAndStorageBuffer8BitAccess"] = True - cap["variablePointers"] = True - cap["variablePointersStorageBuffer"] = True - - elif arch == "valhall": - cap["max_compute_shared_memory_size"] = 32768 - cap["max_compute_workgroup_invocations"] = 512 - cap["max_compute_workgroup_size"] = [512, 512, 512] - - cap["subgroup_size"] = 16 - cap["subgroupFeatures"] = [ - "Basic", - "Vote", - "Arithmetic", - "Ballot", - "Clustered", - "Quad", - ] - - if os == "android31": - cap["subgroupFeatures"].append("Shuffle") - cap["subgroupFeatures"].append("ShuffleRelative") - - cap["shaderFloat16"] = True - cap["shaderInt8"] = True - cap["shaderInt16"] = True - cap["storageBuffer16BitAccess"] = True - cap["storagePushConstant16"] = True - cap["uniformAndStorageBuffer16BitAccess"] = True - cap["storageBuffer8BitAccess"] = True - cap["storagePushConstant8"] = True - cap["uniformAndStorageBuffer8BitAccess"] = True - cap["variablePointers"] = True - cap["variablePointersStorageBuffer"] = True - - elif arch == "arc": - cap["max_compute_shared_memory_size"] = 32768 - cap["max_compute_workgroup_invocations"] = 1024 - cap["max_compute_workgroup_size"] = [1024, 1024, 64] - - cap["subgroup_size"] = 32 - cap["subgroupFeatures"] = [ - "Basic", - "Vote", - "Arithmetic", - "Ballot", - "Shuffle", - "ShuffleRelative", - "Clustered", - "Quad", - ] - - cap["shaderFloat16"] = True - cap["shaderFloat64"] = False - cap["shaderInt8"] = True - cap["shaderInt16"] = True - cap["shaderInt64"] = False - cap["shaderIntegerDotProduct"] = True - cap["storageBuffer16BitAccess"] = True - cap["storagePushConstant16"] = True - cap["uniformAndStorageBuffer16BitAccess"] = True - cap["storageBuffer8BitAccess"] = True - cap["storagePushConstant8"] = True - cap["uniformAndStorageBuffer8BitAccess"] = True - cap["variablePointers"] = True - cap["variablePointersStorageBuffer"] = True - - elif arch == "cpu": - if product == "swiftshader": - cap["max_compute_shared_memory_size"] = 16384 - cap["subgroup_size"] = 4 - cap["subgroupFeatures"] = [ - "Basic", - "Vote", - "Arithmetic", - "Ballot", - "Shuffle", - "ShuffleRelative", - ] - 
- elif arch in ["pascal"]: - cap["max_compute_shared_memory_size"] = 49152 - cap["max_compute_workgroup_invocations"] = 1536 - cap["max_compute_workgroup_size"] = [1536, 1024, 64] - - cap["subgroup_size"] = 32 - cap["min_subgroup_size"] = 32 - cap["max_subgroup_size"] = 32 - cap["subgroupFeatures"] = [ - "Basic", - "Vote", - "Arithmetic", - "Ballot", - "Shuffle", - "ShuffleRelative", - "Clustered", - "Quad", - ] - - cap["shaderFloat16"] = False - cap["shaderFloat64"] = True - cap["shaderInt8"] = True - cap["shaderInt16"] = True - cap["shaderInt64"] = True - cap["shaderIntegerDotProduct"] = True - cap["storageBuffer16BitAccess"] = True - cap["storagePushConstant16"] = True - cap["uniformAndStorageBuffer16BitAccess"] = True - cap["storageBuffer8BitAccess"] = True - cap["storagePushConstant8"] = True - cap["uniformAndStorageBuffer8BitAccess"] = True - cap["variablePointers"] = True - cap["variablePointersStorageBuffer"] = True - - elif arch in ["ampere", "turing"]: - cap["max_compute_shared_memory_size"] = 49152 - cap["max_compute_workgroup_invocations"] = 1024 - cap["max_compute_workgroup_size"] = [1024, 1024, 1024] - - cap["subgroup_size"] = 32 - cap["min_subgroup_size"] = 32 - cap["max_subgroup_size"] = 32 - cap["subgroupFeatures"] = [ - "Basic", - "Vote", - "Arithmetic", - "Ballot", - "Shuffle", - "ShuffleRelative", - "Clustered", - "Quad", - ] - - cap["shaderFloat16"] = True - cap["shaderFloat64"] = True - cap["shaderInt8"] = True - cap["shaderInt16"] = True - cap["shaderInt64"] = True - cap["shaderIntegerDotProduct"] = True - cap["storageBuffer16BitAccess"] = True - cap["storagePushConstant16"] = True - cap["uniformAndStorageBuffer16BitAccess"] = True - cap["storageBuffer8BitAccess"] = True - cap["storagePushConstant8"] = True - cap["uniformAndStorageBuffer8BitAccess"] = True - cap["variablePointers"] = True - cap["variablePointersStorageBuffer"] = True - - cap["coopmatCases"] = [ - "mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope", - "mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope", - "mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope", - ] - - elif arch == "adreno": - cap["max_compute_shared_memory_size"] = 32768 - cap["max_compute_workgroup_invocations"] = 1024 - cap["max_compute_workgroup_size"] = [1024, 1024, 64] - - cap["subgroup_size"] = 64 - cap["subgroupFeatures"] = [ - "Basic", - "Vote", - "Arithmetic", - "Ballot", - "Shuffle", - "ShuffleRelative", - "Quad", - ] - - cap["shaderFloat16"] = True - cap["shaderInt8"] = True - cap["shaderInt16"] = True - - cap["storageBuffer16BitAccess"] = True - if os == "android31": - cap["uniformAndStorageBuffer8BitAccess"] = True - - cap["variablePointers"] = True - cap["variablePointersStorageBuffer"] = True - - elif arch == "unknown": - cap["subgroup_size"] = 64 - cap["variablePointers"] = False - cap["variablePointersStorageBuffer"] = False - else: - print( - f"Architecture {arch} not matched. 
Using default vulkan target device capability" - ) - - def get_comma_sep_str(ele_list): - l = "" - for ele in ele_list: - l += f"{ele}, " - l = f"[{l[:-2]}]" - return l - - res = "" - for k, v in cap.items(): - if v is None or v == False: - continue - if isinstance(v, bool): - res += f"{k} = {'unit' if v == True else None}, " - elif isinstance(v, list): - if k == "subgroupFeatures": - res += f"subgroup_features = {get_subgroup_val(v)}: i32, " - elif k == "max_compute_workgroup_size": - res += f"max_compute_workgroup_size = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, " - elif k == "coopmatCases": - cmc = "" - for case in v: - cmc += f"#spirv.coop_matrix_props_khr<{case}>, " - res += f"cooperative_matrix_properties_khr = [{cmc[:-2]}], " - else: - res += f"{k} = {get_comma_sep_str(v)}, " - else: - res += f"{k} = {v}, " - res = res[:-2] - return res diff --git a/shark/iree_utils/vulkan_utils.py b/shark/iree_utils/vulkan_utils.py deleted file mode 100644 index 96ad33602a..0000000000 --- a/shark/iree_utils/vulkan_utils.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# All the iree_vulkan related functionalities go here. - -import functools -from os import linesep -from shark.iree_utils._common import run_cmd -import iree.runtime as ireert -from sys import platform -from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag -from shark.parser import shark_args - - -@functools.cache -def get_all_vulkan_devices(): - from iree.runtime import get_driver - - try: - driver = get_driver("vulkan") - device_list_src = driver.query_available_devices() - except: - device_list_src = {} - - return [d["name"] for d in device_list_src] - - -@functools.cache -def get_vulkan_device_name(device_num=0): - if isinstance(device_num, int): - vulkaninfo_list = get_all_vulkan_devices() - - if len(vulkaninfo_list) == 0: - raise ValueError("No device name found in VulkanInfo!") - if len(vulkaninfo_list) > 1: - print("Following devices found:") - for i, dname in enumerate(vulkaninfo_list): - print(f"{i}. {dname}") - print(f"Choosing device: vulkan://{device_num}") - vulkan_device_name = vulkaninfo_list[device_num] - else: - from iree.runtime import get_driver - - vulkan_device_driver = get_driver(device_num) - vulkan_device_name = vulkan_device_driver.query_available_devices()[0] - print(vulkan_device_name) - return vulkan_device_name - - -def get_os_name(): - if platform.startswith("linux"): - return "linux" - elif platform == "darwin": - return "macos" - elif platform == "win32": - return "windows" - else: - print("Cannot detect OS type, defaulting to linux.") - return "linux" - - -@functools.cache -def get_vulkan_target_triple(device_name): - """This method provides a target triple str for specified vulkan device. 
- - Args: - device_name (str): name of the hardware device to be used with vulkan - - Returns: - str or None: target triple or None if no match found for given name - """ - - # TODO: Replace this with a dict or something smarter. - system_os = get_os_name() - # Apple Targets - if all(x in device_name for x in ("Apple", "M1")): - triple = "m1-moltenvk-macos" - elif all(x in device_name for x in ("Apple", "M2")): - triple = "m1-moltenvk-macos" - - # Nvidia Targets - elif all(x in device_name for x in ("RTX", "2080")): - triple = f"turing-rtx2080-{system_os}" - elif all(x in device_name for x in ("A100", "SXM4")): - triple = f"ampere-a100-{system_os}" - elif all(x in device_name for x in ("RTX", "3090")): - triple = f"ampere-rtx3090-{system_os}" - elif all(x in device_name for x in ("RTX", "3080")): - triple = f"ampere-rtx3080-{system_os}" - elif all(x in device_name for x in ("RTX", "3070")): - triple = f"ampere-rtx3070-{system_os}" - elif all(x in device_name for x in ("RTX", "3060")): - triple = f"ampere-rtx3060-{system_os}" - elif all(x in device_name for x in ("RTX", "3050")): - triple = f"ampere-rtx3050-{system_os}" - # We use ampere until lovelace target triples are plumbed in. - elif all(x in device_name for x in ("RTX", "4090")): - triple = f"ampere-rtx4090-{system_os}" - elif all(x in device_name for x in ("RTX", "4080")): - triple = f"ampere-rtx4080-{system_os}" - elif all(x in device_name for x in ("RTX", "4070")): - triple = f"ampere-rtx4070-{system_os}" - elif all(x in device_name for x in ("RTX", "4000")): - triple = f"turing-rtx4000-{system_os}" - elif all(x in device_name for x in ("RTX", "5000")): - triple = f"turing-rtx5000-{system_os}" - elif all(x in device_name for x in ("RTX", "6000")): - triple = f"turing-rtx6000-{system_os}" - elif all(x in device_name for x in ("RTX", "8000")): - triple = f"turing-rtx8000-{system_os}" - elif all(x in device_name for x in ("TITAN", "RTX")): - triple = f"turing-titanrtx-{system_os}" - elif all(x in device_name for x in ("GTX", "1060")): - triple = f"pascal-gtx1060-{system_os}" - elif all(x in device_name for x in ("GTX", "1070")): - triple = f"pascal-gtx1070-{system_os}" - elif all(x in device_name for x in ("GTX", "1080")): - triple = f"pascal-gtx1080-{system_os}" - - # Amd Targets - # Linux: Radeon RX 7900 XTX - # Windows: AMD Radeon RX 7900 XTX - elif all(x in device_name for x in ("RX", "7800")): - triple = f"rdna3-7800-{system_os}" - elif all(x in device_name for x in ("RX", "7900")): - triple = f"rdna3-7900-{system_os}" - elif all(x in device_name for x in ("Radeon", "780M")): - triple = f"rdna3-780m-{system_os}" - elif all(x in device_name for x in ("AMD", "PRO", "W7900")): - triple = f"rdna3-w7900-{system_os}" - elif any(x in device_name for x in ("AMD", "Radeon")): - triple = f"rdna2-unknown-{system_os}" - # Intel Targets - elif any(x in device_name for x in ("A770", "A750")): - triple = f"arc-770-{system_os}" - elif "v620" in device_name: - triple = f"rdna2-v620-{system_os}" - - # Adreno Targets - elif all(x in device_name for x in ("Adreno", "740")): - triple = f"adreno-a740-{system_os}" - - else: - triple = None - return triple - - -def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]): - for flag in extra_args: - if "-iree-vulkan-target-triple=" in flag: - print(f"Using target triple {flag.split('=')[1]}") - return None - - if device_name == "" or device_name == [] or device_name is None: - vulkan_device = get_vulkan_device_name(device_num=device_num) - else: - vulkan_device = device_name - triple = 
get_vulkan_target_triple(vulkan_device) - if triple is not None: - print( - f"Found vulkan device {vulkan_device}. Using target triple {triple}" - ) - return f"--iree-vulkan-target-triple={triple}" - print( - """Optimized kernel for your target device is not added yet. - Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u] - or pull up an issue.""" - ) - print(f"Target : {vulkan_device}") - return None - - -def get_iree_vulkan_args(device_num=0, extra_args=[]): - # res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"] - - res_vulkan_flag = [] - res_vulkan_flag += [ - "--iree-stream-resource-max-allocation-size=3221225472", - "--iree-flow-inline-constants-max-byte-length=0" - ] - vulkan_triple_flag = None - for arg in extra_args: - if "-iree-vulkan-target-triple=" in arg: - print(f"Using target triple {arg} from command line args") - vulkan_triple_flag = arg - break - - if vulkan_triple_flag is None: - vulkan_triple_flag = get_vulkan_triple_flag( - device_num=device_num, extra_args=extra_args - ) - res_vulkan_flag += [vulkan_triple_flag] - - return res_vulkan_flag - - -@functools.cache -def get_iree_vulkan_runtime_flags(): - vulkan_runtime_flags = [ - f"--vulkan_validation_layers={'true' if shark_args.vulkan_debug_utils else 'false'}", - f"--vulkan_debug_verbosity={'4' if shark_args.vulkan_debug_utils else '0'}" - f"--vulkan-robust-buffer-access={'true' if shark_args.vulkan_debug_utils else 'false'}", - ] - return vulkan_runtime_flags - - -def set_iree_vulkan_runtime_flags(flags): - for flag in flags: - ireert.flags.parse_flags(flag) - return diff --git a/shark/model_annotation.py b/shark/model_annotation.py deleted file mode 100644 index 4126a3add1..0000000000 --- a/shark/model_annotation.py +++ /dev/null @@ -1,468 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -This function takes the model mlir file and the tuned config file as input, -and output a new mlir file with lowering configs annotated on certain ops. -There are two ways to utilize the function: -1. Call model_annotation function within another python script -from shark.model_annotation import model_annotation -with create_context() as ctx: - module = model_annotation(ctx, input_contents=..., config_path=..., search_op=...) -2. 
Run model_annotation.py directly -python model_annotation.py -model path_to_original_mlir -config_path path_to_config_file -""" - -import json -import os -import sys -from typing import Dict, List - -import iree.compiler._mlir_libs -from iree.compiler import ir - - -def model_annotation( - ctx: ir.Context, - *, - input_contents: str, - config_path: str, - search_op: str, - winograd: bool = False, -): - if os.path.isfile(input_contents): - with open(input_contents, "rb") as f: - input_contents = f.read() - module = ir.Module.parse(input_contents) - - if config_path == "": - return module - - if winograd: - with open(config_path, "r") as f: - data = json.load(f) - configs = data["c,f"] - else: - configs = load_model_configs(config_path) - - # The Python API does not expose a general walk() function, so we just - # do it ourselves. - walk_children(module.operation, configs, search_op, winograd) - - if not module.operation.verify(): - raise RuntimeError("Modified program does not verify!") - - return module - - -def load_model_configs(config_path: str): - config = {} - with open(config_path, "r") as f: - for line in f: - data = json.loads(line) - - if "identifier" not in data.keys(): - continue - if data["identifier"] == "matmul": - matrix_size = [data["m"], data["n"], data["k"]] - elif data["identifier"] == "bmm": - matrix_size = [data["b"], data["m"], data["n"], data["k"]] - elif data["identifier"] == "generic": - matrix_size = [1, data["b"], data["m"], data["n"], data["k"]] - elif data["identifier"] == "conv": - matrix_size = [ - data["n"], - data["ih"], - data["iw"], - data["c"], - data["kh"], - data["kw"], - data["f"], - data["oh"], - data["ow"], - data["d"], - data["s"], - data["p"], - ] - config[shape_list_to_string(matrix_size)] = data - f.close() - return config - - -def walk_children( - op: ir.Operation, configs: List[Dict], search_op: str, winograd: bool -): - if search_op == "matmul": - op_names = ["linalg.matmul", "mhlo.dot"] - elif search_op == "bmm": - op_names = ["linalg.batch_matmul", "mhlo.dot_general"] - elif search_op == "conv": - op_names = ["mhlo.convolution", "linalg.conv_2d_nhwc_hwcf"] - elif search_op == "generic": - op_names = ["linalg.generic"] - elif search_op == "all": - op_names = [ - "mhlo.dot", - "mhlo.dot_general", - "mhlo.convolution", - "linalg.matmul", - "linalg.batch_matmul", - "linalg.conv_2d_nhwc_hwcf", - "linalg.generic", - ] - else: - raise ValueError(f"{search_op} op is not tunable.") - - for region in op.regions: - for block in region.blocks: - for child_op in block.operations: - # TODO: This is dumb. Both Operation and OpView should expose - # 'operation' and 'name' attributes. 
- if isinstance(child_op, ir.OpView): - child_op = child_op.operation - if winograd and child_op.name in [ - "linalg.conv_2d_nchw_fchw", - "linalg.conv_2d_nhwc_hwcf", - ]: - add_winograd_attribute(child_op, configs) - if child_op.name in op_names: - if child_op.name == "linalg.generic": - # This is for generic op that has contractionOpInterface - # which is basically einsum("mk,bkn->bmn") - op_result = str(child_op.results[0]) - op_iterator = str( - child_op.attributes["iterator_types"] - ) - if len(child_op.operands) != 3: - continue - if "reduction" not in op_iterator: - continue - if ( - "arith.addf" not in op_result - or "arith.mulf" not in op_result - ): - continue - if "arith.subf" in op_result: - continue - - child_op_shape = get_op_shape(child_op, search_op) - if ( - child_op_shape in configs.keys() - and configs[child_op_shape]["options"][0] != None - ): - add_attributes( - child_op, configs[child_op_shape]["options"][0] - ) - - walk_children(child_op, configs, search_op, winograd) - - -def get_op_shape(op: ir.Operation, search_op: str): - shape_list = [] - if search_op in ["generic", "all"]: - if op.name in ["linalg.generic"]: - input1 = str(op.operands[0].type) - input2 = str(op.operands[1].type) - m = input1.split("tensor<")[1].split("x")[0] - b = input2.split("tensor<")[1].split("x")[0] - k = input2.split("tensor<")[1].split("x")[1] - n = input2.split("tensor<")[1].split("x")[2] - shape_list = [1, int(b), int(m), int(n), int(k)] - - if search_op in ["matmul", "all"]: - if op.name in ["mhlo.dot"]: - op_result = str(op.results[0]) - m = op_result.split("tensor<")[1].split("x")[0] - k = op_result.split("tensor<")[1].split("x")[1] - n = op_result.split("tensor<")[2].split("x")[1] - shape_list = [int(m), int(n), int(k)] - elif op.name in ["linalg.matmul"]: - op_result = str(op.results[0]).split("ins(")[1] - m = op_result.split("tensor<")[1].split("x")[0] - k = op_result.split("tensor<")[1].split("x")[1] - n = op_result.split("tensor<")[2].split("x")[1] - shape_list = [int(m), int(n), int(k)] - - if search_op in ["bmm", "all"]: - if op.name in ["mhlo.dot_general"]: - op_result = str(op.results[0]) - b = op_result.split("tensor<")[1].split("x")[1] - m = op_result.split("tensor<")[1].split("x")[2] - k = op_result.split("tensor<")[1].split("x")[3] - n = op_result.split("tensor<")[3].split("x")[3] - shape_list = [int(b), int(m), int(n), int(k)] - elif op.name in ["linalg.batch_matmul"]: - op_result = str(op.results[0]).split("ins(")[1] - b = op_result.split("tensor<")[1].split("x")[0] - m = op_result.split("tensor<")[1].split("x")[1] - k = op_result.split("tensor<")[1].split("x")[2] - n = op_result.split("tensor<")[3].split("x")[2] - shape_list = [int(b), int(m), int(n), int(k)] - - if search_op in ["conv", "all"]: - if op.name in ["mhlo.convolution"]: - op_result = str(op.results[0]) - dilation = ( - str(op.attributes["rhs_dilation"]) - .split("dense<")[1] - .split(">")[0] - ) - stride = ( - str(op.attributes["window_strides"]) - .split("dense<")[1] - .split(">")[0] - ) - pad = ( - str(op.attributes["padding"]).split("dense<")[1].split(">")[0] - ) - n = op_result.split("tensor<")[1].split("x")[0] - ih = op_result.split("tensor<")[1].split("x")[1] - iw = op_result.split("tensor<")[1].split("x")[2] - c = op_result.split("tensor<")[1].split("x")[3] - kh = op_result.split("tensor<")[2].split("x")[0] - kw = op_result.split("tensor<")[2].split("x")[1] - f = op_result.split("tensor<")[2].split("x")[3] - oh = op_result.split("tensor<")[3].split("x")[1] - ow = 
op_result.split("tensor<")[3].split("x")[2] - shape_list = [ - int(n), - int(ih), - int(iw), - int(c), - int(kh), - int(kw), - int(f), - int(oh), - int(ow), - int(dilation), - int(stride), - int(pad), - ] - - elif op.name in ["linalg.conv_2d_nhwc_hwcf"]: - op_result = str(op.results[0]).split("ins(")[1] - dilation = ( - str(op.attributes["dilations"]) - .split("dense<")[1] - .split(">")[0] - ) - stride = ( - str(op.attributes["strides"]).split("dense<")[1].split(">")[0] - ) - pad = 0 - n = op_result.split("tensor<")[1].split("x")[0] - ih = op_result.split("tensor<")[1].split("x")[1] - iw = op_result.split("tensor<")[1].split("x")[2] - c = op_result.split("tensor<")[1].split("x")[3] - kh = op_result.split("tensor<")[2].split("x")[0] - kw = op_result.split("tensor<")[2].split("x")[1] - f = op_result.split("tensor<")[2].split("x")[3] - oh = op_result.split("tensor<")[3].split("x")[1] - ow = op_result.split("tensor<")[3].split("x")[2] - shape_list = [ - int(n), - int(ih), - int(iw), - int(c), - int(kh), - int(kw), - int(f), - int(oh), - int(ow), - int(dilation), - int(stride), - int(pad), - ] - - shape_str = shape_list_to_string(shape_list) - return shape_str - - -def add_attributes(op: ir.Operation, config: List[Dict]): - # Parse the config file - split_k = None - pipeline_depth = None - store_stage = None - subgroup_size = None - - if "GPU" in config["pipeline"]: - pipeline = ( - "LLVMGPUMatmulSimt" - if config["pipeline"] == "GPU" - else "LLVMGPUMatmulTensorCore" - ) - tile_sizes = [config["work_group_tile_sizes"]] - workgroup_size = config["work_group_sizes"] - if "pipeline_depth" in config.keys(): - pipeline_depth = config["pipeline_depth"] - if "split_k" in config.keys(): - split_k = config["split_k"] - elif "SPIRV" in config["pipeline"]: - pipeline = config["pipeline"] - if pipeline == "SPIRVMatmulPromoteVectorize": - tile_sizes = [ - config["work_group_tile_sizes"] - + [config["reduction_tile_sizes"][-1]], - ] - else: - tile_sizes = [ - config["work_group_tile_sizes"], - config["parallel_tile_sizes"], - config["reduction_tile_sizes"], - ] - - workgroup_size = config["work_group_sizes"] - if "vector_tile_sizes" in config.keys(): - tile_sizes += [config["vector_tile_sizes"]] - if "window_tile_sizes" in config.keys(): - tile_sizes += [config["window_tile_sizes"]] - if "subgroup_size" in config.keys(): - subgroup_size = config["subgroup_size"] - if "pipeline_depth" in config.keys(): - pipeline_depth = config["pipeline_depth"] - if "store_stage" in config.keys(): - store_stage = config["store_stage"] - else: - # For IREE CPU pipelines - pipeline = config["pipeline"] - tile_sizes = [ - config["work_group_tile_sizes"], - config["parallel_tile_sizes"], - config["reduction_tile_sizes"], - ] - workgroup_size = [] - - # Add compilation info as an attribute. We don't have a Python binding for CompilationInfo, - # so we just parse its string form. - if pipeline_depth != None: - translation_info = f"{pipeline} pipeline_depth = {pipeline_depth}" - if store_stage != None: - translation_info += f" store_stage = {store_stage}" - else: - translation_info = f"{pipeline}" - - compilation_info = ( - f"#iree_codegen.compilation_info<" - f"lowering_config = , " - f"translation_info = <{translation_info}>, " - f"workgroup_size = {repr(workgroup_size)} " - ) - - if subgroup_size != None: - compilation_info += f", subgroup_size = {subgroup_size}>" - else: - compilation_info += ">" - - attr = ir.Attribute.parse(compilation_info) - op.attributes["compilation_info"] = attr - - # Add other attributes if required. 
- if split_k: - add_attribute_by_name(op, "iree_flow_split_k", split_k) - - -def add_winograd_attribute(op: ir.Operation, config: List): - op_result = str(op.results[0]).split("ins(")[1] - dilation = int( - str(op.attributes["dilations"]).split("dense<")[1].split(">")[0] - ) - stride = int( - str(op.attributes["strides"]).split("dense<")[1].split(">")[0] - ) - - if op.name == "linalg.conv_2d_nchw_fchw": - f = int(op_result.split("tensor<")[2].split("x")[0]) - c = int(op_result.split("tensor<")[2].split("x")[1]) - kh = int(op_result.split("tensor<")[2].split("x")[2]) - kw = int(op_result.split("tensor<")[2].split("x")[3]) - else: - kh = int(op_result.split("tensor<")[2].split("x")[0]) - kw = int(op_result.split("tensor<")[2].split("x")[1]) - c = int(op_result.split("tensor<")[2].split("x")[2]) - f = int(op_result.split("tensor<")[2].split("x")[3]) - - if ( - dilation == 1 - and stride == 1 - and kh == 3 - and kw == 3 - and [c, f] in config - ): - op.attributes["iree_winograd_conv"] = ir.IntegerAttr.get( - ir.IntegerType.get_signless(64), 1 - ) - - -def add_attribute_by_name(op: ir.Operation, name: str, val: int): - attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val) - op.attributes[name] = attr - - -def shape_list_to_string(input): - return "x".join([str(d) for d in input]) - - -def create_context() -> ir.Context: - context = ir.Context() - context.allow_unregistered_dialects = True - return context - - -if __name__ == "__main__": - import argparse - from pathlib import Path - - def path_expand(s): - return Path(s).expanduser().resolve() - - parser = argparse.ArgumentParser() - parser.add_argument( - "-model", - type=path_expand, - default="model.mlir", - help="Path to the input mlir file", - ) - parser.add_argument( - "-config_path", - type=path_expand, - default="best_configs.json", - help="Path where stores the op config file", - ) - parser.add_argument( - "-output_path", - type=path_expand, - default="tuned_model.mlir", - help="Path to save the annotated mlir file", - ) - parser.add_argument( - "-search_op", - type=str, - default="all", - help="Op to be optimized. options are matmul, bmm, conv.", - ) - - args = parser.parse_args() - - with create_context() as ctx: - module = model_annotation( - ctx, - input_contents=args.model, - config_path=args.config_path, - search_op=args.search_op, - ) - mlir_str = str(module) - with open(args.output_path, "w") as f: - f.write(mlir_str) - print(f"Saved mlir in {args.output_path}.") diff --git a/shark/parser.py b/shark/parser.py deleted file mode 100644 index 008c42f014..0000000000 --- a/shark/parser.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
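The core of the annotation pass removed above is just the MLIR Python bindings: parse the module, walk nested regions, and attach attributes by name. A minimal sketch of that walk, assuming mlir_text holds the model IR; the function name and the split-k value of 4 are illustrative.

from iree.compiler import ir

def annotate_matmuls(mlir_text, split_k=4):
    ctx = ir.Context()
    ctx.allow_unregistered_dialects = True
    with ctx:
        module = ir.Module.parse(mlir_text)

        def walk(op):
            # Recurse through regions/blocks, tagging every linalg.matmul.
            for region in op.regions:
                for block in region.blocks:
                    for child in block.operations:
                        child = (
                            child.operation
                            if isinstance(child, ir.OpView)
                            else child
                        )
                        if child.name == "linalg.matmul":
                            child.attributes["iree_flow_split_k"] = ir.IntegerAttr.get(
                                ir.IntegerType.get_signless(64), split_k
                            )
                        walk(child)

        walk(module.operation)
        return str(module)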
- -import argparse -import os -import shlex -import subprocess - - -class SplitStrToListAction(argparse.Action): - def __init__(self, option_strings, dest, *args, **kwargs): - super(SplitStrToListAction, self).__init__( - option_strings=option_strings, dest=dest, *args, **kwargs - ) - - def __call__(self, parser, namespace, values, option_string=None): - del parser, option_string - setattr(namespace, self.dest, shlex.split(" ")) - - -parser = argparse.ArgumentParser(description="SHARK runner.") - -parser.add_argument( - "--device", - type=str, - default="cpu", - help="Device on which shark_runner runs. options are cpu, cuda, and vulkan", -) -parser.add_argument( - "--additional_compile_args", - default=list(), - nargs=1, - action=SplitStrToListAction, - help="Additional arguments to pass to the compiler. These are appended as the last arguments.", -) -parser.add_argument( - "--additional_runtime_args", - default=list(), - nargs=1, - action=SplitStrToListAction, - help="Additional arguments to pass to the IREE runtime. These are appended as the last arguments.", -) -parser.add_argument( - "--enable_tf32", - type=bool, - default=False, - help="Enables TF32 precision calculations on supported GPUs.", -) -parser.add_argument( - "--model_config_path", - help="Directory to where the tuned model config file is located.", - default=None, -) - -parser.add_argument( - "--num_warmup_iterations", - type=int, - default=5, - help="Run the model for the specified number of warmup iterations.", -) -parser.add_argument( - "--num_iterations", - type=int, - default=100, - help="Run the model for the specified number of iterations.", -) -parser.add_argument( - "--onnx_bench", - default=False, - action="store_true", - help="When enabled, pytest bench results will include ONNX benchmark results.", -) -parser.add_argument( - "--shark_prefix", - default=None, - help="gs://shark_tank//model_directories", -) -parser.add_argument( - "--update_tank", - default=True, - action="store_true", - help="When enabled, SHARK downloader will update local shark_tank if local hash is different from latest upstream hash.", -) -parser.add_argument( - "--force_update_tank", - default=False, - action="store_true", - help="When enabled, SHARK downloader will force an update of local shark_tank artifacts for each request.", -) -parser.add_argument( - "--local_tank_cache", - default=None, - help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.", -) - -parser.add_argument( - "--dispatch_benchmarks", - default=None, - help='dispatches to return benchamrk data on. 
use "All" for all, and None for none.', -) - -parser.add_argument( - "--dispatch_benchmarks_dir", - default="temp_dispatch_benchmarks", - help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"', -) - -parser.add_argument( - "--enable_conv_transform", - default=False, - action="store_true", - help="Enables the --iree-flow-enable-conv-nchw-to-nhwc-transform flag.", -) - -parser.add_argument( - "--enable_img2col_transform", - default=False, - action="store_true", - help="Enables the --iree-flow-enable-conv-img2col-transform flag.", -) - -parser.add_argument( - "--use_winograd", - default=False, - action="store_true", - help="Enables the --iree-flow-enable-conv-winograd-transform flag.", -) - -parser.add_argument( - "--device_allocator", - type=str, - nargs="*", - default=["caching"], - help="Specifies one or more HAL device allocator specs " - "to augment the base device allocator", - choices=["debug", "caching"], -) -parser.add_argument( - "--task_topology_max_group_count", - type=str, - default=None, - help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count", -) - -parser.add_argument( - "--vulkan_debug_utils", - default=False, - action=argparse.BooleanOptionalAction, - help="Profiles vulkan device and collects the .rdc info.", -) - -parser.add_argument( - "--vulkan_validation_layers", - default=False, - action=argparse.BooleanOptionalAction, - help="Flag for disabling vulkan validation layers when benchmarking.", -) - -shark_args, unknown = parser.parse_known_args() diff --git a/shark/shark_benchmark_runner.py b/shark/shark_benchmark_runner.py deleted file mode 100644 index 4d87d78867..0000000000 --- a/shark/shark_benchmark_runner.py +++ /dev/null @@ -1,501 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from shark.shark_runner import SharkRunner -from shark.iree_utils.compile_utils import ( - export_iree_module_to_vmfb, - load_flatbuffer, - get_iree_runtime_config, -) -from shark.iree_utils.benchmark_utils import ( - build_benchmark_args, - run_benchmark_module, -) -from shark.parser import shark_args -from datetime import datetime -import time -from typing import Optional -import csv -import os - -TF_CPU_DEVICE = "/CPU:0" -TF_GPU_DEVICE = "/GPU:0" - - -def _bytes_to_mb_str(bytes_: Optional[int]) -> str: - return "" if bytes_ is None else f"{bytes_ / 1e6:.6f}" - - -class OnnxFusionOptions(object): - def __init__(self): - self.disable_gelu = False - self.disable_layer_norm = False - self.disable_attention = False - self.disable_skip_layer_norm = False - self.disable_embed_layer_norm = False - self.disable_bias_skip_layer_norm = False - self.disable_bias_gelu = False - self.enable_gelu_approximation = False - self.use_mask_index = False - self.no_attention_mask = False - - -def check_requirements(frontend): - import importlib - - has_pkgs = False - if frontend == "torch": - tv_spec = importlib.util.find_spec("torchvision") - has_pkgs = tv_spec is not None - - elif frontend in ["tensorflow", "tf"]: - keras_spec = importlib.util.find_spec("keras") - tf_spec = importlib.util.find_spec("tensorflow") - has_pkgs = keras_spec is not None and tf_spec is not None - - return has_pkgs - - -class SharkBenchmarkRunner(SharkRunner): - # SharkRunner derived class with Benchmarking capabilities. - def __init__( - self, - mlir_module: bytes, - device: str = "none", - mlir_dialect: str = "linalg", - extra_args: list = [], - ): - self.device = shark_args.device if device == "none" else device - self.enable_tf32 = shark_args.enable_tf32 - self.frontend_model = None - self.vmfb_file = None - self.mlir_dialect = mlir_dialect - self.extra_args = extra_args - self.import_args = {} - self.temp_file_to_unlink = None - if not os.path.isfile(mlir_module): - print( - "Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead." - ) - self.compile_str = True - else: - self.compile_str = False - SharkRunner.__init__( - self, - mlir_module, - device, - self.mlir_dialect, - self.extra_args, - compile_vmfb=False, - ) - self.vmfb_file = export_iree_module_to_vmfb( - mlir_module, - device, - ".", - self.mlir_dialect, - extra_args=self.extra_args, - compile_str=self.compile_str, - ) - params = load_flatbuffer( - self.vmfb_file, - device, - mmap=True, - ) - self.iree_compilation_module = params["vmfb"] - self.iree_config = params["config"] - self.temp_file_to_unlink = params["temp_file_to_unlink"] - del params - - def setup_cl(self, input_tensors): - self.benchmark_cl = build_benchmark_args( - self.vmfb_file, - self.device, - input_tensors, - mlir_dialect=self.mlir_dialect, - ) - - def benchmark_frontend(self, modelname): - if self.mlir_dialect in ["linalg", "torch"]: - return self.benchmark_torch(modelname) - - elif self.mlir_dialect in ["mhlo", "tf"]: - return self.benchmark_tf(modelname) - - def benchmark_torch(self, modelname, device="cpu"): - import torch - from tank.model_utils import get_torch_model - - # TODO: Pass this as an arg. currently the best way is to setup with BENCHMARK=1 if we want to use torch+cuda, else use cpu. 
- device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cuda": - torch.set_default_device("cuda:0") - # if self.enable_tf32: - # torch.backends.cuda.matmul.allow_tf32 = True - else: - torch.set_default_dtype(torch.float32) - torch.set_default_device("cpu") - torch_device = torch.device("cuda:0" if device == "cuda" else "cpu") - HFmodel, input = get_torch_model(modelname, self.import_args)[:2] - frontend_model = HFmodel.model - frontend_model.to(torch_device) - if device == "cuda": - frontend_model.cuda() - input.to(torch.device("cuda:0")) - print(input) - else: - frontend_model.cpu() - input.cpu() - - for i in range(shark_args.num_warmup_iterations): - frontend_model.forward(input) - - if device == "cuda": - torch.cuda.reset_peak_memory_stats() - begin = time.time() - for i in range(shark_args.num_iterations): - out = frontend_model.forward(input) - end = time.time() - if device == "cuda": - stats = torch.cuda.memory_stats() - device_peak_b = stats["allocated_bytes.all.peak"] - frontend_model.to(torch.device("cpu")) - input.to(torch.device("cpu")) - torch.cuda.empty_cache() - else: - device_peak_b = None - - print( - f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}" - ) - if device == "cuda": - # Set device to CPU so we don't run into segfaults exiting pytest subprocesses. - torch_device = torch.device("cpu") - return [ - f"{shark_args.num_iterations/(end-begin)}", - f"{((end-begin)/shark_args.num_iterations)*1000}", - "", # host_peak_b (CPU usage) is not reported by PyTorch. - _bytes_to_mb_str(device_peak_b), - ] - - def benchmark_tf(self, modelname): - import os - - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" - import tensorflow as tf - - visible_default = tf.config.list_physical_devices("GPU") - try: - tf.config.set_visible_devices([], "GPU") - visible_devices = tf.config.get_visible_devices() - for device in visible_devices: - assert device.device_type != "GPU" - except: - # Invalid device or cannot modify virtual devices once initialized. - pass - - from tank.model_utils_tf import get_tf_model - - # tf_device = TF_GPU_DEVICE if self.device == "cuda" else TF_CPU_DEVICE - tf_device = TF_CPU_DEVICE - with tf.device(tf_device): - ( - model, - input, - ) = get_tf_model( - modelname, self.import_args - )[:2] - frontend_model = model - - for i in range(shark_args.num_warmup_iterations): - frontend_model.forward(*input) - - if tf_device == TF_GPU_DEVICE: - tf.config.experimental.reset_memory_stats(tf_device) - begin = time.time() - for i in range(shark_args.num_iterations): - out = frontend_model.forward(*input) - end = time.time() - if tf_device == TF_GPU_DEVICE: - memory_info = tf.config.experimental.get_memory_info(tf_device) - device_peak_b = memory_info["peak"] - else: - # tf.config.experimental does not currently support measuring - # CPU memory usage. - device_peak_b = None - - print( - f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}" - ) - return [ - f"{shark_args.num_iterations/(end-begin)}", - f"{((end-begin)/shark_args.num_iterations)*1000}", - "", # host_peak_b (CPU usage) is not reported by TensorFlow. 
- _bytes_to_mb_str(device_peak_b), - ] - - def benchmark_c(self): - iter_per_second, host_peak_b, device_peak_b = run_benchmark_module( - self.benchmark_cl - ) - print(f"Shark-IREE-C benchmark:{iter_per_second} iter/second") - return [ - f"{iter_per_second}", - f"{1000/iter_per_second}", - _bytes_to_mb_str(host_peak_b), - _bytes_to_mb_str(device_peak_b), - ] - - def benchmark_python(self, inputs): - input_list = [x for x in inputs] - for i in range(shark_args.num_warmup_iterations): - self.run("forward", input_list) - - begin = time.time() - for i in range(shark_args.num_iterations): - out = self.run("forward", input_list) - end = time.time() - print( - f"Shark-IREE Python benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}" - ) - return [ - f"{shark_args.num_iterations/(end-begin)}", - f"{((end-begin)/shark_args.num_iterations)*1000}", - ] - - def benchmark_onnx(self, modelname, inputs): - if self.device == "cuda": - print( - "Currently GPU benchmarking on ONNX is not supported in SHARK." - ) - return ["N/A", "N/A"] - else: - from onnxruntime.transformers.benchmark import run_onnxruntime - from onnxruntime.transformers.huggingface_models import MODELS - from onnxruntime.transformers.benchmark_helper import ( - ConfigModifier, - Precision, - ) - import psutil - - if modelname == "microsoft/MiniLM-L12-H384-uncased": - modelname = "bert-base-uncased" - if modelname not in MODELS: - print( - f"{modelname} is currently not supported in ORT's HF. Check \ -https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/huggingface_models.py \ -for currently supported models. Exiting benchmark ONNX." - ) - return ["N/A", "N/A"] - use_gpu = self.device == "cuda" - num_threads = psutil.cpu_count(logical=False) - batch_sizes = [1] - sequence_lengths = [128] - cache_dir = os.path.join(".", "cache_models") - onnx_dir = os.path.join(".", "onnx_models") - verbose = False - input_counts = [1] - optimize_onnx = True - validate_onnx = False - disable_ort_io_binding = False - use_raw_attention_mask = True - model_fusion_statistics = {} - overwrite = False - model_source = "pt" # Either "pt" or "tf" - provider = None - config_modifier = ConfigModifier(None) - onnx_args = OnnxFusionOptions() - result = run_onnxruntime( - use_gpu, - provider, - (modelname,), - None, - config_modifier, - Precision.FLOAT32, - num_threads, - batch_sizes, - sequence_lengths, - shark_args.num_iterations, - input_counts, - optimize_onnx, - validate_onnx, - cache_dir, - onnx_dir, - verbose, - overwrite, - disable_ort_io_binding, - use_raw_attention_mask, - model_fusion_statistics, - model_source, - onnx_args, - ) - print( - f"ONNX ORT-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}" - ) - return [ - result[0]["QPS"], - result[0]["average_latency_ms"], - ] - - def get_metadata(self, modelname): - metadata_path = os.path.join(".", "tank", "model_metadata.csv") - with open(metadata_path, mode="r") as csvfile: - torch_reader = csv.reader(csvfile, delimiter=",") - fields = next(torch_reader) - for row in torch_reader: - torch_model_name = row[0] - if torch_model_name == modelname: - param_count = row[3] - model_tags = row[4] - model_notes = row[5] - return [param_count, model_tags, model_notes] - - def compare_bench_results(self, baseline: str, result: str): - if baseline is not None: - # Takes a baseline and a result string and calculates a comparison, e.g. "1.04x baseline". 
- a = float(baseline) - b = float(result) - comparison = a / b - comp_str = f"{round(comparison, 2)}x baseline" - else: - comp_str = "N/A" - - return comp_str - - def benchmark_all_csv( - self, - inputs: tuple, - modelname, - dynamic, - device_str, - frontend, - import_args, - mode="native", - ): - self.setup_cl(inputs) - self.import_args = import_args - self.mode = mode - field_names = [ - "model", - "batch_size", - "engine", - "dialect", - "device", - "shape_type", - "data_type", - "iter/sec", - "ms/iter", - "vs. PyTorch/TF", - "iterations", - "param_count", - "tags", - "notes", - "datetime", - "host_memory_mb", - "device_memory_mb", - "measured_host_memory_mb", - "measured_device_memory_mb", - ] - # "frontend" must be the first element. - if self.mode == "native": - engines = ["shark_python", "shark_iree_c"] - if self.mode == "baseline": - engines = ["frontend"] - if self.mode == "all": - engines = ["frontend", "shark_python", "shark_iree_c"] - - if shark_args.onnx_bench == True: - engines.append("onnxruntime") - - if not os.path.exists("bench_results.csv"): - with open("bench_results.csv", mode="w", newline="") as f: - writer = csv.writer(f) - writer.writerow(field_names) - - with open("bench_results.csv", mode="a", newline="") as f: - writer = csv.DictWriter(f, fieldnames=field_names) - bench_info = {} - bench_info["model"] = modelname - bench_info["batch_size"] = str(import_args["batch_size"]) - bench_info["dialect"] = self.mlir_dialect - bench_info["iterations"] = shark_args.num_iterations - if dynamic == True: - bench_info["shape_type"] = "dynamic" - else: - bench_info["shape_type"] = "static" - bench_info["device"] = device_str - if "fp16" in modelname: - bench_info["data_type"] = "float16" - else: - bench_info["data_type"] = inputs[0].dtype - - for e in engines: - engine_result = {} - self.frontend_result = None - if e == "frontend": - engine_result["engine"] = frontend - if check_requirements(frontend): - ( - engine_result["iter/sec"], - engine_result["ms/iter"], - engine_result["host_memory_mb"], - engine_result["device_memory_mb"], - ) = self.benchmark_frontend(modelname) - self.frontend_result = engine_result["ms/iter"] - engine_result["vs. PyTorch/TF"] = "baseline" - ( - engine_result["param_count"], - engine_result["tags"], - engine_result["notes"], - ) = self.get_metadata(modelname) - else: - self.frontend_result = None - continue - - elif e == "shark_python": - engine_result["engine"] = "shark_python" - ( - engine_result["iter/sec"], - engine_result["ms/iter"], - ) = self.benchmark_python(inputs) - - engine_result[ - "vs. PyTorch/TF" - ] = self.compare_bench_results( - self.frontend_result, engine_result["ms/iter"] - ) - - elif e == "shark_iree_c": - engine_result["engine"] = "shark_iree_c" - ( - engine_result["iter/sec"], - engine_result["ms/iter"], - engine_result["host_memory_mb"], - engine_result["device_memory_mb"], - ) = self.benchmark_c() - - engine_result[ - "vs. 
PyTorch/TF" - ] = self.compare_bench_results( - self.frontend_result, engine_result["ms/iter"] - ) - - elif e == "onnxruntime": - engine_result["engine"] = "onnxruntime" - ( - engine_result["iter/sec"], - engine_result["ms/iter"], - ) = self.benchmark_onnx(modelname, inputs) - - engine_result["datetime"] = str(datetime.now()) - writer.writerow(bench_info | engine_result) diff --git a/shark/shark_compile.py b/shark/shark_compile.py deleted file mode 100644 index 49b003defa..0000000000 --- a/shark/shark_compile.py +++ /dev/null @@ -1,241 +0,0 @@ -import os -import tempfile -from shark.shark_inference import SharkInference -from shark.shark_importer import import_with_fx, save_mlir -import torch -import torch_mlir -from torch_mlir.compiler_utils import run_pipeline_with_repro_report -from typing import List, Tuple -from io import BytesIO -from brevitas_examples.common.generative.quantize import quantize_model -from brevitas_examples.llm.llm_quant.run_utils import get_model_impl - - -# fmt: off -def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]: - if len(lhs) == 3 and len(rhs) == 2: - return [lhs[0], lhs[1], rhs[0]] - elif len(lhs) == 2 and len(rhs) == 2: - return [lhs[0], rhs[0]] - else: - raise ValueError("Input shapes not supported.") - - -def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int: - # output dtype is the dtype of the lhs float input - lhs_rank, lhs_dtype = lhs_rank_dtype - return lhs_dtype - - -def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None: - return - - -brevitas_matmul_rhs_group_quant_library = [ - quant〇matmul_rhs_group_quant〡shape, - quant〇matmul_rhs_group_quant〡dtype, - quant〇matmul_rhs_group_quant〡has_value_semantics] -# fmt: on - - -def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]): - vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb") - shark_module = None - if os.path.isfile(vmfb_path): - shark_module = SharkInference( - None, - device=device, - mlir_dialect=mlir_dialect, - ) - print(f"loading existing vmfb from: {vmfb_path}") - shark_module.load_module(vmfb_path, extra_args=extra_args) - return shark_module - - -def compile_module( - shark_module, extended_model_name, generate_vmfb, extra_args=[] -): - if generate_vmfb: - vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb") - if os.path.isfile(vmfb_path): - print(f"loading existing vmfb from: {vmfb_path}") - shark_module.load_module(vmfb_path, extra_args=extra_args) - else: - print( - "No vmfb found. 
Compiling and saving to {}".format(vmfb_path) - ) - path = shark_module.save_module( - os.getcwd(), extended_model_name, extra_args - ) - shark_module.load_module(path, extra_args=extra_args) - else: - shark_module.compile(extra_args) - return shark_module - - -def compile_int_precision( - model, inputs, precision, device, generate_vmfb, extended_model_name -): - weight_bit_width = 4 if precision == "int4" else 8 - weight_group_size = 128 - quantize_model( - get_model_impl(model), - dtype=torch.float32, - weight_quant_type="asym", - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - input_bit_width=None, - input_scale_type="float", - input_param_method="stats", - input_quant_type="asym", - input_quant_granularity="per_tensor", - quantize_input_zero_point=False, - seqlen=2048, - ) - print("Weight quantization applied.") - torchscript_module = import_with_fx( - model, - inputs, - precision=precision, - mlir_type="torchscript", - ) - mlir_module = torch_mlir.compile( - torchscript_module, - inputs, - output_type="torch", - backend_legal_ops=["quant.matmul_rhs_group_quant"], - extra_library=brevitas_matmul_rhs_group_quant_library, - use_tracing=False, - verbose=False, - ) - print(f"[DEBUG] converting torch to linalg") - run_pipeline_with_repro_report( - mlir_module, - "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", - description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", - ) - from contextlib import redirect_stdout - - mlir_file_path = os.path.join( - os.getcwd(), f"{extended_model_name}_linalg.mlir" - ) - with open(mlir_file_path, "w") as f: - with redirect_stdout(f): - print(mlir_module.operation.get_asm()) - mlir_module = str(mlir_module) - mlir_module = mlir_module.encode("UTF-8") - mlir_module = BytesIO(mlir_module) - bytecode = mlir_module.read() - bytecode_path = os.path.join( - os.getcwd(), f"{extended_model_name}_linalg.mlirbc" - ) - with open(bytecode_path, "wb") as f: - f.write(bytecode) - del bytecode - del mlir_module - print(f"Elided IR written for {extended_model_name}") - return bytecode_path - shark_module = SharkInference( - mlir_module=bytecode_path, device=device, mlir_dialect="tm_tensor" - ) - extra_args = [ - "--iree-hal-dump-executable-sources-to=ies", - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ] - return ( - compile_module( - shark_module, - extended_model_name=extended_model_name, - generate_vmfb=generate_vmfb, - extra_args=extra_args, - ), - bytecode_path, - ) - - -def shark_compile_through_fx( - model, - inputs, - extended_model_name, - precision, - f16_input_mask=None, - save_dir=tempfile.gettempdir(), - debug=False, - generate_or_load_vmfb=True, - extra_args=[], - device=None, - mlir_dialect="tm_tensor", -): - is_f16 = precision == "fp16" - if generate_or_load_vmfb: - shark_module = load_vmfb( - extended_model_name=extended_model_name, - device=device, - mlir_dialect=mlir_dialect, - extra_args=extra_args, - ) - if shark_module: - return ( - shark_module, - None, - ) - - from shark.parser import shark_args - - if "cuda" in device: - shark_args.enable_tf32 = True - - if precision in ["int4", "int8"]: - mlir_module = compile_int_precision( - model, - 
inputs, - precision, - device, - generate_or_load_vmfb, - extended_model_name, - ) - extra_args = [ - "--iree-hal-dump-executable-sources-to=ies", - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ] - else: - ( - bytecode, - _, - ) = import_with_fx( - model=model, - inputs=inputs, - is_f16=is_f16, - f16_input_mask=f16_input_mask, - debug=debug, - model_name=extended_model_name, - save_dir=save_dir, - ) - mlir_module = save_mlir( - mlir_module=bytecode, - model_name=extended_model_name, - mlir_dialect=mlir_dialect, - ) - - shark_module = SharkInference( - mlir_module, - device=device, - mlir_dialect=mlir_dialect, - ) - return ( - compile_module( - shark_module, - extended_model_name, - generate_vmfb=generate_or_load_vmfb, - extra_args=extra_args, - ), - mlir_module, - ) diff --git a/shark/shark_downloader.py b/shark/shark_downloader.py deleted file mode 100644 index b9baf32595..0000000000 --- a/shark/shark_downloader.py +++ /dev/null @@ -1,297 +0,0 @@ -# Lint as: python3 -"""SHARK Downloader""" -# Requirements : Put shark_tank in SHARK directory -# /SHARK -# /gen_shark_tank -# /tflite -# /albert_lite_base -# /...model_name... -# /tf -# /pytorch -# -# -# - -import numpy as np -import os -from tqdm.std import tqdm -import sys -from pathlib import Path -from shark.parser import shark_args -from google.cloud import storage - - -def download_public_file( - full_gs_url, destination_folder_name, single_file=False -): - """Downloads a public blob from the bucket.""" - # bucket_name = "gs://your-bucket-name/path/to/file" - # destination_file_name = "local/path/to/file" - - storage_client = storage.Client.create_anonymous_client() - bucket_name = full_gs_url.split("/")[2] - source_blob_name = None - dest_filename = None - desired_file = None - if single_file: - desired_file = full_gs_url.split("/")[-1] - source_blob_name = "/".join(full_gs_url.split("/")[3:-1]) - destination_folder_name, dest_filename = os.path.split( - destination_folder_name - ) - else: - source_blob_name = "/".join(full_gs_url.split("/")[3:]) - bucket = storage_client.bucket(bucket_name) - blobs = bucket.list_blobs(prefix=source_blob_name) - if not os.path.exists(destination_folder_name): - os.mkdir(destination_folder_name) - for blob in blobs: - blob_name = blob.name.split("/")[-1] - if single_file: - if blob_name == desired_file: - destination_filename = os.path.join( - destination_folder_name, dest_filename - ) - with open(destination_filename, "wb") as f: - with tqdm.wrapattr( - f, "write", total=blob.size - ) as file_obj: - storage_client.download_blob_to_file(blob, file_obj) - else: - continue - - else: - destination_filename = os.path.join( - destination_folder_name, blob_name - ) - if os.path.isdir(destination_filename): - continue - with open(destination_filename, "wb") as f: - with tqdm.wrapattr(f, "write", total=blob.size) as file_obj: - storage_client.download_blob_to_file(blob, file_obj) - - -input_type_to_np_dtype = { - "float32": np.float32, - "float64": np.float64, - "bool": np.bool_, - "int32": np.int32, - "int64": np.int64, - "uint8": np.uint8, - "int8": np.int8, -} - -# Save the model in the home local so it needn't be fetched everytime in the CI. 
-home = str(Path.home()) -alt_path = os.path.join(os.path.dirname(__file__), "../gen_shark_tank/") -custom_path = shark_args.local_tank_cache - -if custom_path is not None: - if not os.path.exists(custom_path): - os.mkdir(custom_path) - - WORKDIR = custom_path - - print(f"Using {WORKDIR} as local shark_tank cache directory.") - -elif os.path.exists(alt_path): - WORKDIR = alt_path - print( - f"Using {WORKDIR} as shark_tank directory. Delete this directory if you aren't working from locally generated shark_tank." - ) -else: - WORKDIR = os.path.join(home, ".local/shark_tank/") - print( - f"shark_tank local cache is located at {WORKDIR} . You may change this by setting the --local_tank_cache= flag" - ) -os.makedirs(WORKDIR, exist_ok=True) - - -# Checks whether the directory and files exists. -def check_dir_exists(model_name, frontend="torch", dynamic=""): - model_dir = os.path.join(WORKDIR, model_name) - - # Remove the _tf keyword from end only for non-SD models. - if not any(model in model_name for model in ["clip", "unet", "vae"]): - if frontend in ["tf", "tensorflow"]: - model_name = model_name[:-3] - elif frontend in ["tflite"]: - model_name = model_name[:-7] - elif frontend in ["torch", "pytorch"]: - model_name = model_name[:-6] - - model_mlir_file_name = f"{model_name}{dynamic}_{frontend}.mlir" - - if os.path.isdir(model_dir): - if ( - os.path.isfile(os.path.join(model_dir, model_mlir_file_name)) - and os.path.isfile(os.path.join(model_dir, "function_name.npy")) - and os.path.isfile(os.path.join(model_dir, "inputs.npz")) - and os.path.isfile(os.path.join(model_dir, "golden_out.npz")) - and os.path.isfile(os.path.join(model_dir, "hash.npy")) - ): - print( - f"""Model artifacts for {model_name} found at {WORKDIR}...""" - ) - return True - return False - - -def _internet_connected(): - import requests as req - - try: - req.get("http://1.1.1.1") - return True - except: - return False - - -def get_git_revision_short_hash() -> str: - import subprocess - - if shark_args.shark_prefix is not None: - prefix_kw = shark_args.shark_prefix - else: - import json - - dir_path = os.path.dirname(os.path.realpath(__file__)) - src = os.path.join(dir_path, "..", "tank_version.json") - with open(src, "r") as f: - data = json.loads(f.read()) - prefix_kw = data["version"] - print(f"Checking for updates from gs://shark_tank/{prefix_kw}") - return prefix_kw - - -def get_sharktank_prefix(): - tank_prefix = "" - if not _internet_connected(): - print( - "No internet connection. Using the model already present in the tank." - ) - tank_prefix = "none" - else: - desired_prefix = get_git_revision_short_hash() - storage_client_a = storage.Client.create_anonymous_client() - base_bucket_name = "shark_tank" - base_bucket = storage_client_a.bucket(base_bucket_name) - dir_blobs = base_bucket.list_blobs(prefix=f"{desired_prefix}") - for blob in dir_blobs: - dir_blob_name = blob.name.split("/") - if desired_prefix in dir_blob_name[0]: - tank_prefix = dir_blob_name[0] - break - else: - continue - if tank_prefix == "": - print( - f"shark_tank bucket not found matching ({desired_prefix}). Defaulting to nightly." - ) - tank_prefix = "nightly" - return tank_prefix - - -# Downloads the torch model from gs://shark_tank dir. 
-def download_model( - model_name, - dynamic=False, - tank_url=None, - frontend=None, - tuned=None, - import_args={"batch_size": 1}, -): - model_name = model_name.replace("/", "_") - dyn_str = "_dynamic" if dynamic else "" - os.makedirs(WORKDIR, exist_ok=True) - shark_args.shark_prefix = get_sharktank_prefix() - if import_args["batch_size"] and import_args["batch_size"] != 1: - model_dir_name = ( - model_name - + "_" - + frontend - + "_BS" - + str(import_args["batch_size"]) - ) - elif any(model in model_name for model in ["clip", "unet", "vae"]): - # TODO(Ean Garvey): rework extended naming such that device is only included in model_name after .vmfb compilation. - model_dir_name = model_name - else: - model_dir_name = model_name + "_" + frontend - model_dir = os.path.join(WORKDIR, model_dir_name) - - if not tank_url: - tank_url = "gs://shark_tank/" + shark_args.shark_prefix - - full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name - if not check_dir_exists( - model_dir_name, frontend=frontend, dynamic=dyn_str - ): - print( - f"Downloading artifacts for model {model_name} from: {full_gs_url}" - ) - download_public_file(full_gs_url, model_dir) - - elif shark_args.force_update_tank == True: - print( - f"Force-updating artifacts for model {model_name} from: {full_gs_url}" - ) - download_public_file(full_gs_url, model_dir) - else: - if not _internet_connected(): - print( - "No internet connection. Using the model already present in the tank." - ) - else: - local_hash = str(np.load(os.path.join(model_dir, "hash.npy"))) - gs_hash_url = ( - tank_url.rstrip("/") + "/" + model_dir_name + "/hash.npy" - ) - download_public_file( - gs_hash_url, - os.path.join(model_dir, "upstream_hash.npy"), - single_file=True, - ) - try: - upstream_hash = str( - np.load(os.path.join(model_dir, "upstream_hash.npy")) - ) - except FileNotFoundError: - print(f"Model artifact hash not found at {model_dir}.") - upstream_hash = None - if local_hash != upstream_hash and shark_args.update_tank == True: - print(f"Updating artifacts for model {model_name}...") - download_public_file(full_gs_url, model_dir) - - elif local_hash != upstream_hash: - print( - "Hash does not match upstream in gs://shark_tank/. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank." - ) - else: - print( - "Local and upstream hashes match. Using cached model artifacts." - ) - - model_dir = os.path.join(WORKDIR, model_dir_name) - tuned_str = "" if tuned is None else "_" + tuned - suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir" - mlir_filename = os.path.join(model_dir, model_name + suffix) - print( - f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..." - ) - if not os.path.exists(mlir_filename): - from tank.generate_sharktank import gen_shark_files - - print( - "The model data was not found. Trying to generate artifacts locally." 
- ) - gen_shark_files(model_name, frontend, WORKDIR, import_args) - - assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}" - function_name = str(np.load(os.path.join(model_dir, "function_name.npy"))) - inputs = np.load(os.path.join(model_dir, "inputs.npz")) - golden_out = np.load(os.path.join(model_dir, "golden_out.npz")) - - inputs_tuple = tuple([inputs[key] for key in inputs]) - golden_out_tuple = tuple([golden_out[key] for key in golden_out]) - return mlir_filename, function_name, inputs_tuple, golden_out_tuple diff --git a/shark/shark_eager/shark_eager.py b/shark/shark_eager/shark_eager.py deleted file mode 100644 index bd75119947..0000000000 --- a/shark/shark_eager/shark_eager.py +++ /dev/null @@ -1,212 +0,0 @@ -from typing import Any, Dict, List, Tuple -from collections import defaultdict -from shark.shark_importer import import_with_fx, save_mlir -import torchvision.models as models -import copy -import io -import numpy as np -import sys -import torch -import torch.fx -from torch.fx.node import Node -from typing import Dict -import torch_mlir - - -def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"): - mlir_module = torch_mlir.compile( - fx_g, inputs, output_type="linalg-on-tensors" - ) - bytecode_stream = io.BytesIO() - mlir_module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - bytecode_path = save_mlir( - bytecode, - model_name="shark_eager_module", - frontend="torch", - mlir_dialect="tm_tensor", - ) - from shark.shark_inference import SharkInference - - shark_module = SharkInference( - mlir_module=bytecode_path, - device=device, - mlir_dialect="tm_tensor", - ) - shark_module.compile(extra_args=[]) - return shark_module - - -def _make_single_op_gm(node, captured_val, compiled_graph): - """Make a GraphModule that just executes the given node.""" - g = torch.fx.Graph() - env = {} - inputs = [] - for arg in node.args: - if arg and hasattr(arg, "name"): - env[arg.name] = g.placeholder(arg.name) - if isinstance(captured_val[arg.name], (list, tuple)): - for val in captured_val[arg.name]: - inputs.append(val) - else: - inputs.append(captured_val[arg.name]) - - call = g.node_copy(node, lambda n: env[n.name]) - g.output(call) - g.lint() - single_node = torch.fx.GraphModule(torch.nn.Module(), g) - compiled_module = shark_backend(single_node, inputs) - compiled_graph[node.name] = { - "module": compiled_module, - "inputs": [i for i in env], - "result": None, - } - return - - -def compiled_graph(gm: torch.fx.GraphModule, attr_info): - compiled_graph = {} - g = gm.graph - for node in g.nodes: - if node.op == "call_function": - if not ( - node.target in [torch.ops.aten.empty] - or node.name.startswith("getitem") - ): - _make_single_op_gm(node, attr_info, compiled_graph) - - # Currently torch.aten.empty has an compilation issue, so running natively. - elif node.target in [torch.ops.aten.empty]: - compiled_graph[node.name] = { - "target": node.target, - "args": node.args, - "kwargs": node.kwargs, - "result": None, - } - # Get item is a simple case takes a tuple and return the tensor at a particular index. - elif node.name.startswith("getitem"): - compiled_graph[node.name] = { - "input": node.args[0].name, - "pos": node.args[1], - "result": None, - } - - return compiled_graph - - -class ShapeProp: - """ - Shape propagation. This class takes a `GraphModule`. - Then, its `propagate` method executes the `GraphModule` - node-by-node with the given arguments. 
As each operation - executes, the ShapeProp class stores away the shape and - element type for the output values of each operation on - the `shape` and `dtype` attributes of the operation's - `Node`. - """ - - def __init__(self, mod): - self.mod = mod - self.graph = mod.graph - self.modules = dict(self.mod.named_modules()) - - def propagate(self, *args): - args_iter = iter(args) - env: Dict[str, Node] = {} - - def load_arg(a): - return torch.fx.graph.map_arg(a, lambda n: env[n.name]) - - def fetch_attr(target: str): - target_atoms = target.split(".") - attr_itr = self.mod - for i, atom in enumerate(target_atoms): - if not hasattr(attr_itr, atom): - raise RuntimeError( - f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}" - ) - attr_itr = getattr(attr_itr, atom) - return attr_itr - - for node in self.graph.nodes: - if node.op == "placeholder": - result = next(args_iter) - elif node.op == "get_attr": - result = fetch_attr(node.target) - elif node.op == "call_function": - result = node.target( - *load_arg(node.args), **load_arg(node.kwargs) - ) - elif node.op == "call_method": - self_obj, *args = load_arg(node.args) - kwargs = load_arg(node.kwargs) - result = getattr(self_obj, node.target)(*args, **kwargs) - elif node.op == "call_module": - result = self.modules[node.target]( - *load_arg(node.args), **load_arg(node.kwargs) - ) - - # This is the only code specific to shape propagation. - # you can delete this `if` branch and this becomes - # a generic GraphModule interpreter. - if isinstance(result, torch.Tensor): - node.shape = result.shape - node.dtype = result.dtype - - env[node.name] = result - - return env - - # return load_arg(self.graph.result) - - -resnet18 = models.resnet18(pretrained=True) -resnet18.train(False) -input = (torch.randn(1, 3, 224, 224),) - -print(resnet18(input[0])) - -fx_graph = import_with_fx(resnet18, input, mlir_type="fx") - -shape_prop = ShapeProp(fx_graph) - -x = shape_prop.propagate(input[0]) - -shark_graph = compiled_graph(fx_graph, x) - - -for key in shark_graph: - if key.startswith("getitem"): - input_val = shark_graph[key]["input"] - pos = shark_graph[key]["pos"] - if input_val not in shark_graph: - shark_graph[key]["result"] = x[input_val][pos].detach() - else: - shark_graph[key]["result"] = shark_graph[input_val]["result"][ - pos - ].detach() - elif key.startswith("empty"): - operator = shark_graph[key]["target"] - args = shark_graph[key]["args"] - kwargs = shark_graph[key]["kwargs"] - shark_graph[key]["result"] = operator(*args, **kwargs).detach() - else: - input_val = shark_graph[key]["inputs"] - input_tensors = [] - for input in input_val: - if input not in shark_graph: - input_tensors.append(x[input].detach()) - else: - input_tensors.append(shark_graph[input]["result"]) - - val = shark_graph[key]["module"]("forward", input_tensors) - if isinstance(val, (tuple, list)): - list_val = [] - for v in val: - list_val.append(torch.from_numpy(v)) - shark_graph[key]["result"] = list_val - else: - shark_graph[key]["result"] = torch.from_numpy(val) - - -print(shark_graph) diff --git a/shark/shark_generate_model_config.py b/shark/shark_generate_model_config.py deleted file mode 100644 index 9847b11603..0000000000 --- a/shark/shark_generate_model_config.py +++ /dev/null @@ -1,153 +0,0 @@ -import re -import json -import numpy as np - -import torch_mlir -from iree.compiler import compile_file -from shark.shark_importer import import_with_fx, get_f16_inputs, save_mlir - - -class GenerateConfigFile: - def __init__( - self, - model, - num_sharding_stages: 
int, - sharding_stages_id: list[str], - units_in_each_stage: list[int], - model_input=None, - config_file_path="model_config.json", - ): - self.model = model - self.num_sharding_stages = num_sharding_stages - self.sharding_stages_id = sharding_stages_id - assert self.num_sharding_stages == len( - self.sharding_stages_id - ), "Number of sharding stages should be equal to the list of their ID" - self.model_input = model_input - self.config_file_path = config_file_path - # (Nithin) this is a quick fix - revisit and rewrite - self.units_in_each_stage = np.array(units_in_each_stage) - self.track_loop = np.zeros(len(self.sharding_stages_id)).astype(int) - - def split_into_dispatches( - self, - backend, - fx_tracing_required=False, - f16_model=False, - torch_mlir_tracing=True, - ): - graph_for_compilation = self.model - if fx_tracing_required: - graph_for_compilation = import_with_fx( - self.model, - self.model_input, - is_f16=f16_model, - f16_input_mask=[False, False], - mlir_type="torchscript", - ) - - module = torch_mlir.compile( - graph_for_compilation, - (self.model_input), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=torch_mlir_tracing, - verbose=False, - ) - module = module.operation.get_asm(large_elements_limit=4) - module_file = save_mlir( - module, - model_name="module_pre_split", - frontend="torch", - mlir_dialect="linalg", - ) - compiled_module_str = str( - compile_file( - module_file, - target_backends=[backend], - extra_args=[ - "--compile-to=flow", - "--mlir-elide-elementsattrs-if-larger=4", - ], - ) - ) - - substring_start_idx = [ - m.start() - for m in re.finditer("flow.dispatch @", compiled_module_str) - ] - dispatch_list = dict() - - # dispatch_no is the 'i'th index of a dispatch out of n total dispatches of a model - # dispatch_id is the unique id of a dispatch, multiple instances of the same dispatch - # can occur in a model - for dispatch_no, substring_idx in enumerate(substring_start_idx): - dispatch_idx = ( - compiled_module_str[substring_idx:] - .split(":")[0] - .split("@")[-1] - ) - key = "dispatch_no_" + str(dispatch_no) - dispatch_list[key] = {n: "None" for n in self.sharding_stages_id} - dispatch_list[key]["dispatch_id"] = dispatch_idx - - self.generate_json(dispatch_list) - - def split_into_layers(self): - model_dictionary = dict() - - for name, m in self.model.named_modules(): - if name == "": - continue - - # Remove non-leaf nodes from the config as they aren't an operation - substring_before_final_period = name.split(".")[:-1] - substring_before_final_period = ".".join( - substring_before_final_period - ) - if substring_before_final_period in model_dictionary: - del model_dictionary[substring_before_final_period] - - # layer_dict = {n: "None" for n in self.sharding_stages_id} - - # By default embed increasing device id's for each layer - increasing_wraparound_idx_list = ( - self.track_loop % self.units_in_each_stage - ) - layer_dict = { - n: int(increasing_wraparound_idx_list[idx][0][0]) - for idx, n in enumerate(self.sharding_stages_id) - } - self.track_loop += 1 - model_dictionary[name] = layer_dict - - self.generate_json(model_dictionary) - - def generate_json(self, artifacts): - with open(self.config_file_path, "w") as outfile: - json.dump(artifacts, outfile) - - -if __name__ == "__main__": - import torch - from transformers import AutoTokenizer - - hf_model_path = "TheBloke/vicuna-7B-1.1-HF" - tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False) - compilation_prompt = "".join(["0" for _ in range(17)]) - compilation_input_ids = 
tokenizer( - compilation_prompt, - return_tensors="pt", - ).input_ids - compilation_input_ids = torch.tensor(compilation_input_ids).reshape( - [1, 19] - ) - firstVicunaCompileInput = (compilation_input_ids,) - from apps.language_models.src.model_wrappers.vicuna_model import ( - FirstVicuna, - SecondVicuna7B, - CombinedModel, - ) - - model = CombinedModel() - c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput) - c.split_into_layers() diff --git a/shark/shark_importer.py b/shark/shark_importer.py deleted file mode 100644 index e05b81fa0d..0000000000 --- a/shark/shark_importer.py +++ /dev/null @@ -1,819 +0,0 @@ -# Lint as: python3 -"""SHARK Importer""" - -import sys -import tempfile -import os -import hashlib - -from apps.shark_studio.modules.shared_cmd_opts import cmd_opts - -def create_hash(file_name): - with open(file_name, "rb") as f: - file_hash = hashlib.blake2b(digest_size=64) - while chunk := f.read(2**10): - file_hash.update(chunk) - - return file_hash.hexdigest() - - -# List of the supported frontends. -supported_frontends = { - "tensorflow", - "tf", - "pytorch", - "torch", - "tf-lite", - "tflite", -} - - -class SharkImporter: - """ - SharkImporter converts frontend modules into a - mlir_module. The supported frameworks are tensorflow, - pytorch, and tf-lite. - - ... - - Attributes - ---------- - module : - torch, tensorflow or tf-lite module. - inputs : - inputs to the module, may be required for the shape - information. - frontend: str - frontend to which the module belongs. - raw_model_file: str - temp tflite model path - - Methods - ------- - import_mlir(is_dynamic, tracing_required, func_name): - is_dynamic: input shapes to be totally dynamic (pytorch specific). - tracing_required: whether tracing is required (pytorch specific. - func_name: The function to be traced out or imported to mlir. - - import_debug(is_dynamic, tracing_required, func_name): - returns the converted (mlir_module,func_name) with inputs and golden - outputs. - The inputs and outputs are converted into np array. - """ - - def __init__( - self, - module, - inputs: tuple = (), - frontend: str = "torch", - raw_model_file: str = "", - return_str: bool = False, - ): - self.module = module - self.inputs = None if len(inputs) == 0 else inputs - self.frontend = frontend - if not self.frontend in supported_frontends: - print( - f"The frontend is not in the supported_frontends: {supported_frontends}" - ) - sys.exit(1) - self.raw_model_file = raw_model_file - self.return_str = return_str - - # NOTE: The default function for torch is "forward" and tf-lite is "main". - - def _torch_mlir(self, is_dynamic, tracing_required, mlir_type): - from shark.torch_mlir_utils import get_torch_mlir_module - - return get_torch_mlir_module( - self.module, - self.inputs, - is_dynamic, - tracing_required, - self.return_str, - mlir_type, - ) - - def _tf_mlir(self, func_name, save_dir="."): - from iree.compiler import tf as tfc - - return tfc.compile_module( - self.module, - exported_names=[func_name], - import_only=True, - output_file=save_dir, - ) - - def _tflite_mlir(self, func_name, save_dir="."): - from iree.compiler import tflite as tflitec - - self.mlir_model = tflitec.compile_file( - self.raw_model_file, # in tflite, it is a path to .tflite file, not a tflite interpreter - input_type="tosa", - import_only=True, - output_file=save_dir, - ) - return self.mlir_model - - # Adds the conversion of the frontend with the private function. 
- def import_mlir( - self, - is_dynamic=False, - tracing_required=False, - func_name="forward", - save_dir=cmd_opts.tmp_dir, #"./shark_tmp/", - mlir_type="linalg", - ): - if self.frontend in ["torch", "pytorch"]: - if self.inputs == None: - print( - "Please pass in the inputs, the inputs are required to determine the shape of the mlir_module" - ) - sys.exit(1) - return ( - self._torch_mlir(is_dynamic, tracing_required, mlir_type), - func_name, - ) - if self.frontend in ["tf", "tensorflow"]: - return self._tf_mlir(func_name, save_dir), func_name - if self.frontend in ["tflite", "tf-lite"]: - func_name = "main" - return self._tflite_mlir(func_name, save_dir), func_name - - # Converts the frontend specific tensors into np array. - def convert_to_numpy(self, array_tuple: tuple): - if self.frontend in ["torch", "pytorch"]: - return [x.detach().cpu().numpy() for x in array_tuple] - if self.frontend in ["tf", "tensorflow"]: - return [x.numpy() for x in array_tuple] - - # Saves `function_name.npy`, `inputs.npz`, `golden_out.npz` and `model_name.mlir` in the directory `dir`. - def save_data( - self, - dir, - model_name, - mlir_data, - func_name, - inputs, - outputs, - mlir_type="linalg", - ): - import numpy as np - - inputs_name = "inputs.npz" - outputs_name = "golden_out.npz" - func_file_name = "function_name" - model_name_mlir = ( - model_name + "_" + self.frontend + "_" + mlir_type + ".mlir" - ) - print(f"saving {model_name_mlir} to {dir}") - try: - inputs = [x.cpu().detach() for x in inputs] - except AttributeError: - try: - inputs = [x.numpy() for x in inputs] - except AttributeError: - inputs = [x for x in inputs] - np.savez(os.path.join(dir, inputs_name), *inputs) - np.savez(os.path.join(dir, outputs_name), *outputs) - np.save(os.path.join(dir, func_file_name), np.array(func_name)) - if self.frontend == "torch": - with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file: - mlir_file.write(mlir_data) - hash_gen_attempts = 2 - for i in range(hash_gen_attempts): - try: - mlir_hash = create_hash(os.path.join(dir, model_name_mlir)) - except FileNotFoundError as err: - if i < hash_gen_attempts: - continue - else: - raise err - - np.save(os.path.join(dir, "hash"), np.array(mlir_hash)) - return - - def import_debug( - self, - is_dynamic=False, - tracing_required=False, - func_name="forward", - dir=tempfile.gettempdir(), - model_name="model", - golden_values=None, - mlir_type="linalg", - ): - if self.inputs == None: - print( - f"There is no input provided: {self.inputs}, please provide inputs or simply run import_mlir." - ) - sys.exit(1) - model_name_mlir = ( - model_name + "_" + self.frontend + "_" + mlir_type + ".mlir" - ) - artifact_path = os.path.join(dir, model_name_mlir) - imported_mlir = self.import_mlir( - is_dynamic, - tracing_required, - func_name, - save_dir=artifact_path, - mlir_type=mlir_type, - ) - # TODO: Make sure that any generic function name is accepted. Currently takes in the default function names. - # TODO: Check for multiple outputs. - if self.frontend in ["torch", "pytorch"]: - import torch - - golden_out = None - if golden_values is not None: - golden_out = golden_values - else: - golden_out = self.module(*self.inputs) - if torch.is_tensor(golden_out): - golden_out = tuple( - golden_out.detach().cpu().numpy(), - ) - else: - golden_out = self.convert_to_numpy(golden_out) - # Save the artifacts in the directory dir. 
- self.save_data( - dir, - model_name, - imported_mlir[0], - imported_mlir[1], - self.inputs, - golden_out, - mlir_type, - ) - return ( - imported_mlir, - self.convert_to_numpy(self.inputs), - golden_out, - ) - if self.frontend in ["tf", "tensorflow"]: - import tensorflow as tf - - golden_out = self.module.forward(*self.inputs) - if tf.is_tensor(golden_out): - golden_out = tuple( - golden_out.numpy(), - ) - elif golden_out is tuple: - golden_out = self.convert_to_numpy(golden_out) - elif hasattr(golden_out, "logits"): - # from transformers import TFSequenceClassifierOutput - golden_out = golden_out.logits - else: - golden_out = golden_out.last_hidden_state - # Save the artifacts in the directory dir. - self.save_data( - dir, - model_name, - imported_mlir[0], - imported_mlir[1], - self.inputs, - golden_out, - ) - return ( - imported_mlir, - self.convert_to_numpy(self.inputs), - golden_out, - ) - if self.frontend in ["tflite", "tf-lite"]: - # TODO(Chi): Validate it for tflite models. - golden_out = self.module.invoke_tflite(self.inputs) - self.save_data( - dir, - model_name, - imported_mlir[0], - imported_mlir[1], - self.inputs, - golden_out, - ) - return ( - imported_mlir, - self.inputs, - golden_out, - ) - - -def get_f16_inputs(inputs, is_f16, f16_input_mask): - if is_f16 == False: - return inputs - if f16_input_mask == None: - return tuple([x.half() for x in inputs]) - - f16_masked_inputs = [] - for i in range(len(inputs)): - if f16_input_mask[i]: - f16_masked_inputs.append(inputs[i].half()) - else: - f16_masked_inputs.append(inputs[i]) - - return tuple(f16_masked_inputs) - - -# Upcasts the block/list of ops. -def add_upcast(fx_g): - import torch - - for node in fx_g.graph.nodes: - if node.target in [torch.ops.aten.mul]: - # This is a very strict check. - if hasattr(node.args[1], "target"): - if ( - node.args[1].target in [torch.ops.aten.rsqrt] - and node.args[1].args[0].target in [torch.ops.aten.add] - and node.args[1].args[0].args[0].target - in [torch.ops.aten.mean] - and node.args[1].args[0].args[0].args[0].target - in [torch.ops.aten.pow] - ): - print("found an upcasting block let's upcast it.") - pow_node = node.args[1].args[0].args[0].args[0] - mul_node = node - with fx_g.graph.inserting_before(pow_node): - lhs = pow_node.args[0] - upcast_lhs = fx_g.graph.call_function( - torch.ops.aten._to_copy, - args=(lhs,), - kwargs={"dtype": torch.float32}, - ) - pow_node.args = (upcast_lhs, pow_node.args[1]) - with fx_g.graph.inserting_before(mul_node): - new_node = fx_g.graph.call_function( - torch.ops.aten._to_copy, - args=(mul_node,), - kwargs={"dtype": torch.float16}, - ) - mul_node.append(new_node) - mul_node.replace_all_uses_with(new_node) - new_node.args = (mul_node,) - new_node.kwargs = {"dtype": torch.float16} - - fx_g.graph.lint() - - -def transform_fx(fx_g, quantized=False): - import torch - - kwargs_dict = { - "dtype": torch.float16, - "device": torch.device(type="cpu"), - "pin_memory": False, - } - kwargs_dict1 = { - "dtype": torch.float16, - } - for node in fx_g.graph.nodes: - if node.op == "call_function": - # aten.empty should be filled with zeros. 
- if node.target in [torch.ops.aten.empty]: - with fx_g.graph.inserting_after(node): - new_node = fx_g.graph.call_function( - torch.ops.aten.zero_, - args=(node,), - ) - node.append(new_node) - node.replace_all_uses_with(new_node) - new_node.args = (node,) - if quantized: - continue - - if node.target in [ - torch.ops.aten.arange, - torch.ops.aten.empty, - torch.ops.aten.zeros, - torch.ops.aten.zeros_like, - ]: - if node.kwargs.get("dtype") == torch.float32: - node.kwargs = kwargs_dict - - # Vicuna - if node.target in [ - torch.ops.aten._to_copy, - ]: - if node.kwargs.get("dtype") == torch.float32: - node.kwargs = kwargs_dict1 - - if node.target in [ - torch.ops.aten.masked_fill, - ]: - if node.args[2] > torch.finfo(torch.half).max: - max_val = torch.finfo(torch.half).max - node.args = (node.args[0], node.args[1], max_val) - elif node.args[2] < torch.finfo(torch.half).min: - min_val = torch.finfo(torch.half).min - node.args = (node.args[0], node.args[1], min_val) - - if node.target in [ - torch.ops.aten.full, - ]: - if node.args[1] > torch.finfo(torch.half).max: - max_val = torch.finfo(torch.half).max - node.args = (node.args[0], max_val) - node.kwargs = kwargs_dict - elif node.args[1] < torch.finfo(torch.half).min: - min_val = torch.finfo(torch.half).min - node.args = (node.args[0], min_val) - node.kwargs = kwargs_dict - - # Inputs and outputs of aten.var.mean should be upcasted to fp32. - if node.target in [torch.ops.aten.var_mean]: - with fx_g.graph.inserting_before(node): - new_node = fx_g.graph.call_function( - torch.ops.prims.convert_element_type, - args=(node.args[0], torch.float32), - kwargs={}, - ) - node.args = (new_node, node.args[1]) - - if node.name.startswith("getitem"): - with fx_g.graph.inserting_before(node): - if node.args[0].target in [torch.ops.aten.var_mean]: - new_node = fx_g.graph.call_function( - torch.ops.aten._to_copy, - args=(node,), - kwargs={"dtype": torch.float16}, - ) - node.append(new_node) - node.replace_all_uses_with(new_node) - new_node.args = (node,) - new_node.kwargs = {"dtype": torch.float16} - - # Required for cuda debugging. - # for node in fx_g.graph.nodes: - # if node.op == "call_function": - # if node.kwargs.get("device") == torch.device(type="cpu"): - # new_kwargs = node.kwargs.copy() - # new_kwargs["device"] = torch.device(type="cuda") - # node.kwargs = new_kwargs - - fx_g.graph.lint() - - -def gptq_transforms(fx_g): - import torch - - for node in fx_g.graph.nodes: - if node.op == "call_function": - if node.target in [ - torch.ops.aten.arange, - torch.ops.aten.empty, - torch.ops.aten.ones, - torch.ops.aten._to_copy, - ]: - if node.kwargs.get("device") == torch.device(device="cuda:0"): - updated_kwargs = node.kwargs.copy() - updated_kwargs["device"] = torch.device(device="cpu") - node.kwargs = updated_kwargs - - if node.target in [ - torch.ops.aten._to_copy, - ]: - if node.kwargs.get("dtype") == torch.bfloat16: - updated_kwargs = node.kwargs.copy() - updated_kwargs["dtype"] = torch.float16 - node.kwargs = updated_kwargs - - # Inputs of aten.native_layer_norm should be upcasted to fp32. - if node.target in [torch.ops.aten.native_layer_norm]: - with fx_g.graph.inserting_before(node): - new_node_arg0 = fx_g.graph.call_function( - torch.ops.prims.convert_element_type, - args=(node.args[0], torch.float32), - kwargs={}, - ) - node.args = ( - new_node_arg0, - node.args[1], - node.args[2], - node.args[3], - node.args[4], - ) - - # Inputs of aten.mm should be upcasted to fp32. 
- if node.target in [torch.ops.aten.mm]: - with fx_g.graph.inserting_before(node): - new_node_arg0 = fx_g.graph.call_function( - torch.ops.prims.convert_element_type, - args=(node.args[0], torch.float32), - kwargs={}, - ) - new_node_arg1 = fx_g.graph.call_function( - torch.ops.prims.convert_element_type, - args=(node.args[1], torch.float32), - kwargs={}, - ) - node.args = (new_node_arg0, new_node_arg1) - - # Outputs of aten.mm should be downcasted to fp16. - if type(node.args[0]) == torch.fx.node.Node and node.args[ - 0 - ].target in [torch.ops.aten.mm]: - with fx_g.graph.inserting_before(node): - tmp = node.args[0] - new_node = fx_g.graph.call_function( - torch.ops.aten._to_copy, - args=(node.args[0],), - kwargs={"dtype": torch.float16}, - ) - node.args[0].append(new_node) - node.args[0].replace_all_uses_with(new_node) - new_node.args = (tmp,) - new_node.kwargs = {"dtype": torch.float16} - - # Inputs of aten._softmax should be upcasted to fp32. - if node.target in [torch.ops.aten._softmax]: - with fx_g.graph.inserting_before(node): - new_node_arg0 = fx_g.graph.call_function( - torch.ops.prims.convert_element_type, - args=(node.args[0], torch.float32), - kwargs={}, - ) - node.args = (new_node_arg0, node.args[1], node.args[2]) - - # Outputs of aten._softmax should be downcasted to fp16. - if ( - type(node.args[0]) == torch.fx.node.Node - and node.args[0].target in [torch.ops.aten._softmax] - and node.target in [torch.ops.aten.expand] - ): - with fx_g.graph.inserting_before(node): - tmp = node.args[0] - new_node = fx_g.graph.call_function( - torch.ops.aten._to_copy, - args=(node.args[0],), - kwargs={"dtype": torch.float16}, - ) - node.args[0].append(new_node) - node.args[0].replace_all_uses_with(new_node) - new_node.args = (tmp,) - new_node.kwargs = {"dtype": torch.float16} - - fx_g.graph.lint() - - -# Doesn't replace the None type. -def change_fx_graph_return_to_tuple(fx_g): - for node in fx_g.graph.nodes: - if node.op == "output": - # output nodes always have one argument - node_arg = node.args[0] - out_nodes = [] - if isinstance(node_arg, list): - # Don't return NoneType elements. - for out_node in node_arg: - if not isinstance(out_node, type(None)): - out_nodes.append(out_node) - # If there is a single tensor/element to be returned don't - # a tuple for it. - if len(out_nodes) == 1: - node.args = out_nodes - else: - node.args = (tuple(out_nodes),) - fx_g.graph.lint() - fx_g.recompile() - return fx_g - - -def flatten_training_input(inputs): - flattened_input = [] - for i in inputs: - if isinstance(i, dict): - for value in i.values(): - flattened_input.append(value.detach()) - elif isinstance(i, tuple): - for value in i: - flattened_input.append(value) - else: - flattened_input.append(i) - return tuple(flattened_input) - - -# TODO: Remove is_f16 and fix all calls with using precision instead -# Applies fx conversion to the model and imports the mlir. 
-def import_with_fx( - model, - inputs, - is_f16=False, - f16_input_mask=None, - debug=False, - training=False, - return_str=False, - save_dir=tempfile.gettempdir(), - model_name="model", - mlir_type="linalg", - is_dynamic=False, - tracing_required=False, - precision="fp32", - is_gptq=False, -): - import torch - from torch.fx.experimental.proxy_tensor import make_fx - from torch._decomp import get_decompositions - from typing import List - - golden_values = None - if debug: - try: - golden_values = model(*inputs) - except: - golden_values = None - - def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]: - removed_indexes = [] - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, (list, tuple)): - node_arg = list(node_arg) - node_args_len = len(node_arg) - for i in range(node_args_len): - curr_index = node_args_len - (i + 1) - if node_arg[curr_index] is None: - removed_indexes.append(curr_index) - node_arg.pop(curr_index) - node.args = (tuple(node_arg),) - break - - if len(removed_indexes) > 0: - fx_g.graph.lint() - fx_g.graph.eliminate_dead_code() - fx_g.recompile() - removed_indexes.sort() - return removed_indexes - - def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool: - """ - Replace tuple with tuple element in functions that return one-element tuples. - Returns true if an unwrapping took place, and false otherwise. - """ - unwrapped_tuple = False - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, tuple): - if len(node_arg) == 1: - node.args = (node_arg[0],) - unwrapped_tuple = True - break - - if unwrapped_tuple: - fx_g.graph.lint() - fx_g.recompile() - return unwrapped_tuple - - # TODO: Control the decompositions. 
- decomps_list = [ - torch.ops.aten.embedding_dense_backward, - torch.ops.aten.native_layer_norm_backward, - torch.ops.aten.slice_backward, - torch.ops.aten.select_backward, - torch.ops.aten.norm.ScalarOpt_dim, - torch.ops.aten.native_group_norm, - torch.ops.aten.upsample_bilinear2d.vec, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes, - torch.ops.aten.native_layer_norm, - torch.ops.aten.masked_fill.Tensor, - torch.ops.aten.masked_fill.Scalar, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten.index_add, - torch.ops.aten.index_add_, - ] - if precision in ["int4", "int8"] and not is_gptq: - from brevitas_examples.llm.llm_quant.export import ( - block_quant_layer_level_manager, - ) - from brevitas_examples.llm.llm_quant.export import ( - brevitas_layer_export_mode, - ) - from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import ( - LinearWeightBlockQuantHandlerFwd, - ) - from brevitas_examples.llm.llm_quant.export import ( - replace_call_fn_target, - ) - from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import ( - matmul_rhs_group_quant_placeholder, - ) - from brevitas.backport.fx.experimental.proxy_tensor import ( - make_fx as brevitas_make_fx, - ) - - export_context_manager = brevitas_layer_export_mode - export_class = block_quant_layer_level_manager( - export_handlers=[LinearWeightBlockQuantHandlerFwd] - ) - with export_context_manager(model, export_class): - fx_g = brevitas_make_fx( - model, - decomposition_table=get_decompositions(decomps_list), - )(*inputs) - - transform_fx(fx_g, quantized=True) - replace_call_fn_target( - fx_g, - src=matmul_rhs_group_quant_placeholder, - target=torch.ops.quant.matmul_rhs_group_quant, - ) - - fx_g.recompile() - removed_none_indexes = _remove_nones(fx_g) - was_unwrapped = _unwrap_single_tuple_return(fx_g) - else: - fx_g = make_fx( - model, - decomposition_table=get_decompositions(decomps_list), - )(*inputs) - - fx_g.graph.set_codegen(torch.fx.graph.CodeGen()) - fx_g.recompile() - - def strip_overloads(gm): - """ - Modifies the target of graph nodes in :attr:`gm` to strip overloads. - Args: - gm(fx.GraphModule): The input Fx graph module to be modified - """ - for node in gm.graph.nodes: - if isinstance(node.target, torch._ops.OpOverload): - node.target = node.target.overloadpacket - gm.recompile() - - strip_overloads(fx_g) - - if is_f16: - fx_g = fx_g.half() - transform_fx(fx_g) - # TODO: Have to make it more generic. - add_upcast(fx_g) - fx_g.recompile() - - if is_gptq: - gptq_transforms(fx_g) - fx_g.recompile() - - if mlir_type == "fx": - return fx_g - - if training: - change_fx_graph_return_to_tuple(fx_g) - inputs = flatten_training_input(inputs) - - ts_graph = torch.jit.script(fx_g) - if mlir_type == "torchscript": - return ts_graph - - inputs = get_f16_inputs(inputs, is_f16, f16_input_mask) - mlir_importer = SharkImporter( - ts_graph, - inputs, - frontend="torch", - return_str=return_str, - ) - - if debug: # and not is_f16: - (mlir_module, func_name), _, _ = mlir_importer.import_debug( - dir=save_dir, - model_name=model_name, - golden_values=golden_values, - mlir_type=mlir_type, - is_dynamic=is_dynamic, - tracing_required=tracing_required, - ) - return mlir_module, func_name - - mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type) - return mlir_module, func_name - - -# Saves a .mlir module python object to the directory 'dir' with 'model_name' and returns a path to the saved file. 
-def save_mlir( - mlir_module, - model_name, - mlir_dialect="linalg", - frontend="torch", - dir="", -): - model_name_mlir = ( - model_name + "_" + frontend + "_" + mlir_dialect + ".mlir" - ) - if dir == "": - dir = cmd_opts.tmp_dir, #os.path.join(".", "shark_tmp") - mlir_path = os.path.join(dir, model_name_mlir) - print(f"saving {model_name_mlir} to {dir}") - if not os.path.exists(dir): - os.makedirs(dir) - if frontend == "torch": - with open(mlir_path, "wb") as mlir_file: - mlir_file.write(mlir_module) - - return mlir_path diff --git a/shark/shark_inference.py b/shark/shark_inference.py deleted file mode 100644 index 032137c089..0000000000 --- a/shark/shark_inference.py +++ /dev/null @@ -1,243 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from shark.iree_utils.compile_utils import ( - export_iree_module_to_vmfb, - load_flatbuffer, - create_dispatch_dirs, - compile_benchmark_dirs, -) -import os -from shark.shark_runner import SharkRunner -from shark.parser import shark_args -import numpy as np - - -dtype_to_np_dtype = { - "f32": np.float32, - "f64": np.float64, - "i32": np.int32, - "i64": np.int64, - "i1": np.bool_, -} - - -class SharkInference: - """ - Runs prediction or inference on mlir_module. - - ... - - Attributes - ---------- - mlir_module : str - mlir_module or path represented in string; modules from torch-mlir are serialized in bytecode format. - device : str - device to execute the mlir_module on. - currently supports cpu, cuda, vulkan, and metal backends. - mlir_dialect: str - The dialect in which the given mlir_module is in. - Refer to {https://mlir.llvm.org/docs/Dialects/} - is_benchmark: bool - Whether this SharkInference module should be benchmark-enabled. - mmap: bool - Whether to load/run vmfb using mmap. It's `True` by default. - - Methods - ------- - __call__(function_name, inputs=None): - Runs the function with `function_name` within the mlir_module along - with the given inputs, if the inputs are not given it autogenerates the - inputs. Also, the inputs should be a numpy array. - input_info(): - Gives the information about the inputs required by the `function_name`. - This can be expensive as it does string matching to do so. - - """ - - def __init__( - self, - mlir_module, - device: str = "none", - mlir_dialect: str = "linalg", - is_benchmark: bool = False, - dispatch_benchmark: str = None, - dispatch_benchmark_dir: str = "temp_dispatch_benchmarks", - device_idx: int = None, - mmap: bool = True, - rt_flags: list = [], - ): - self.mlir_module = mlir_module - if mlir_module is not None: - if mlir_module and not os.path.isfile(mlir_module): - print( - "Warning: Initializing SharkInference with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead." 
- ) - self.compile_str = True - else: - self.compile_str = False - self.device = shark_args.device if device == "none" else device - self.mlir_dialect = mlir_dialect - self.is_benchmark = is_benchmark - self.device_idx = device_idx - self.dispatch_benchmarks = ( - shark_args.dispatch_benchmarks - if dispatch_benchmark is None - else dispatch_benchmark - ) - self.dispatch_benchmarks_dir = ( - shark_args.dispatch_benchmarks_dir - if dispatch_benchmark_dir == "temp_dispatch_benchmarks" - else dispatch_benchmark_dir - ) - - self.shark_runner = None - self.mmap = mmap - self.rt_flags = rt_flags - - def compile(self, extra_args=[]): - if self.dispatch_benchmarks is not None: - extra_args.append( - f"--iree-hal-dump-executable-sources-to={self.dispatch_benchmarks_dir}" - ) - extra_args.append( - f"--iree-hal-dump-executable-binaries-to={self.dispatch_benchmarks_dir}" - ) - temp_dir = self.dispatch_benchmarks_dir.split("/") - temp_dir[-1] = "temp_" + temp_dir[-1] - temp_dir = "/".join(temp_dir) - self.temp_dispatch_benchmarks_dir = temp_dir - extra_args.append( - f"--iree-hal-dump-executable-benchmarks-to={self.temp_dispatch_benchmarks_dir}" - ) - - if self.is_benchmark == True: - from shark.shark_benchmark_runner import SharkBenchmarkRunner - - self.shark_runner = SharkBenchmarkRunner( - self.mlir_module, - self.device, - self.mlir_dialect, - extra_args=extra_args, - ) - - else: - self.shark_runner = SharkRunner( - self.mlir_module, - self.device, - self.mlir_dialect, - extra_args=extra_args, - device_idx=self.device_idx, - rt_flags=self.rt_flags, - ) - - if self.dispatch_benchmarks is not None: - create_dispatch_dirs(self.dispatch_benchmarks_dir, self.device) - compile_benchmark_dirs( - self.dispatch_benchmarks_dir, - self.device, - self.dispatch_benchmarks, - ) - os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}") - - # inputs are considered to be tuple of np.array. - def __call__(self, function_name: str, inputs: tuple, send_to_host=True): - return self.shark_runner.run( - function_name, inputs, send_to_host, device=self.device - ) - - # forward function. - def forward(self, inputs: tuple, send_to_host=True): - return self.shark_runner.run( - "forward", inputs, send_to_host, device=self.device - ) - - # Get all function names defined within the compiled module. - def get_functions_in_module(self): - return self.shark_runner.get_functions_in_module() - - # Captures the static input information from the mlir_module. - # TODO(pashu123): Generate the input information for dynamic shapes. - def _input_info(self, function_name): - # func_key to get the line which contains the function. - func_key = "func.func @" + function_name - func_header = None - for line in str(self.mlir_module).splitlines(): - if func_key in line: - func_header = line - break - if func_header is None: - print(f"Function: {function_name} not found") - - import re - - inputs = re.findall("\(.*?\)", func_header)[0].split(",") - shapes = [] - dtype = [] - for inp in inputs: - shape_dtype = re.findall(r"<[^>]*>", inp)[0].split("x") - shape_dtype[0], shape_dtype[-1] = ( - shape_dtype[0][1:], - shape_dtype[-1][:-1], - ) - shapes.append(tuple([int(x) for x in shape_dtype[:-1]])) - dtype.append(shape_dtype[-1]) - - return shapes, dtype - - # Generates random input to be feed into the graph. 
- def generate_random_inputs(self, low=0, high=1): - shapes, dtype = self._input_info() - inputs = [] - for i, j in zip(shapes, dtype): - inputs.append( - np.random.uniform(low, high, size=i).astype( - dtype_to_np_dtype[j] - ) - ) - return tuple(inputs) - - # TODO: Instead of passing directory and having names decided by the module - # , user may want to save the module with manual names. - def save_module( - self, dir=os.getcwd(), module_name=None, extra_args=[], debug=False - ): - return export_iree_module_to_vmfb( - self.mlir_module, - self.device, - dir, - self.mlir_dialect, - module_name=module_name, - extra_args=extra_args, - debug=debug, - compile_str=self.compile_str, - ) - - # load and return the module. - def load_module(self, path, extra_args=[]): - self.shark_runner = SharkRunner( - device=self.device, - compile_vmfb=False, - extra_args=extra_args, - rt_flags=self.rt_flags, - ) - params = load_flatbuffer( - path, - self.device, - self.device_idx, - mmap=self.mmap, - rt_flags=self.rt_flags, - ) - self.shark_runner.iree_compilation_module = params["vmfb"] - self.shark_runner.iree_config = params["config"] - self.shark_runner.temp_file_to_unlink = params["temp_file_to_unlink"] - del params - return diff --git a/shark/shark_runner.py b/shark/shark_runner.py deleted file mode 100644 index 9f24409b2f..0000000000 --- a/shark/shark_runner.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from shark.iree_utils.compile_utils import ( - get_iree_compiled_module, - get_results, - export_iree_module_to_vmfb, - load_flatbuffer, -) -from shark.iree_utils._common import check_device_drivers, device_driver_info -from shark.parser import shark_args -import os -import sys - - -# supported dialects by the shark-runtime. -supported_dialects = { - "linalg", - "auto", - "stablehlo", - "tosa", - "tf-lite", - "tm_tensor", -} - - -class SharkRunner: - """ - Base class for SharkInference and SharkTrainer - used to execute an mlir_module. - - ... - - Attributes - ---------- - mlir_module : str - mlir_module path, string, or bytecode. - device : str - device to execute the mlir_module on. - currently supports cpu, cuda, vulkan, and metal backends. - mlir_dialect: str - The dialect in which the given mlir_module is in. - Refer to {https://mlir.llvm.org/docs/Dialects/} - - Methods - ------- - run(function_name, inputs=None): - Runs the function with `function_name` within the mlir_module along - with the given inputs, if the inputs are not given it autogenerates the - inputs. Also, the inputs should be a numpy array. - input_info(): - Gives the information about the inputs required by the `function_name`. - This can be expensive as it does string matching to do so. 
- """ - - def __init__( - self, - mlir_module: bytes = None, - device: str = "none", - mlir_dialect: str = "linalg", - extra_args: list = [], - compile_vmfb: bool = True, - device_idx: int = None, - rt_flags: list = [], - ): - self.mlir_module = mlir_module - if self.mlir_module is not None: - if not os.path.isfile(mlir_module): - print( - "Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead." - ) - self.compile_str = True - else: - self.compile_str = False - self.device = shark_args.device if device == "none" else device - self.mlir_dialect = mlir_dialect - self.extra_args = extra_args - self.device_idx = device_idx - self.rt_flags = rt_flags - - if check_device_drivers(self.device): - print(device_driver_info(self.device)) - sys.exit(1) - - if compile_vmfb == True: - # Compile the module to get the .vmfb. - params = get_iree_compiled_module( - self.mlir_module, - self.device, - self.mlir_dialect, - extra_args=self.extra_args, - device_idx=self.device_idx, - rt_flags=self.rt_flags, - compile_str=self.compile_str, - ) - self.iree_compilation_module = params["vmfb"] - self.iree_config = params["config"] - self.temp_file_to_unlink = params["temp_file_to_unlink"] - del params - - def run( - self, function_name, inputs: tuple, send_to_host=False, device=None - ): - return get_results( - self.iree_compilation_module, - function_name, - inputs, - self.iree_config, - self.mlir_dialect, - send_to_host, - device=device, - ) - - # Get all function names defined within the compiled module. - def get_functions_in_module(self): - return self.iree_compilation_module._vm_module.function_names diff --git a/shark/shark_trainer.py b/shark/shark_trainer.py deleted file mode 100644 index 16bdd984e9..0000000000 --- a/shark/shark_trainer.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2020 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from shark.parser import shark_args -from shark.shark_runner import SharkRunner -from shark.backward_makefx import MakeFxModule -from shark.shark_importer import import_with_fx, save_mlir -import numpy as np -from tqdm import tqdm -import sys - - -# Prints to stderr. -def print_err(*a): - print(*a, file=sys.stderr) - - -class SharkTrainer: - """Training pytorch, tensorflow module on shark runtime.""" - - def __init__( - self, - model, - input: tuple, - dynamic: bool = False, - device: str = None, - jit_trace: bool = False, - from_aot: bool = True, - ): - self.model = model - # Change tuple to list. - self.input = [x for x in input] - self.dynamic = dynamic - self.from_aot = from_aot - self.jit_trace = jit_trace - self.from_aot = from_aot - - # By default it's the torch frontend. - self.frontend = "pytorch" - self.device = device if device is not None else shark_args.device - - self.shark_runner = None - - # Sets the frontend i.e `pytorch` or `tensorflow`. 
- def set_frontend(self, frontend: str): - if frontend not in [ - "pytorch", - "torch", - "tensorflow", - "tf", - "stablehlo", - "mhlo", - "linalg", - "tosa", - ]: - print_err("frontend not supported.") - else: - self.frontend = frontend - - # Training function is needed in the case of torch_fn. - def compile(self, training_fn=None, mlir_type="linalg", extra_args=[]): - if self.frontend in ["torch", "pytorch"]: - packed_inputs = ( - dict(self.model.named_parameters()), - dict(self.model.named_buffers()), - tuple(self.input), - ) - mlir_module, func_name = import_with_fx( - training_fn, - packed_inputs, - False, - [], - training=True, - mlir_type=mlir_type, - ) - mlir_module = save_mlir( - mlir_module, - model_name="shark_model", - frontend="torch", - mlir_dialect=mlir_type, - ) - self.shark_runner = SharkRunner( - mlir_module, - self.device, - "tm_tensor", - extra_args=extra_args, - ) - elif self.frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]: - self.shark_runner = SharkRunner( - self.model, - self.input, - self.dynamic, - self.device, - self.jit_trace, - self.from_aot, - self.frontend, - ) - else: - print_err("Unknown frontend") - return - - # The inputs to the mlir-graph are weights, buffers and inputs respectively. - def get_torch_params(self): - params = [i.detach() for i in self.model.parameters()] - buffers = [i.detach() for i in self.model.buffers()] - return params + buffers - - # Function to train pytorch module. - def _train_torch(self, num_iters): - """Returns the updated weights after num_iters""" - params = self.get_torch_params() - params = [x.numpy() for x in params] - print(f"Training started for {num_iters} iterations:") - for i in tqdm(range(num_iters)): - params = self.shark_runner.run( - "forward", params + self.input, self.frontend - ) - - return params - - # Function to train tensorflow module. - # Output final loss. - # TODO(raikonenfnu): Save updated weight/states in SHARK. - def _train_tf(self, num_iters): - input_list = [] - for x in self.input: - if isinstance(x, list): - nested_list = [] - for val in x: - if isinstance(val, np.ndarray): - nested_list.append(val) - else: - nested_list.append(val.numpy()) - input_list.append(nested_list) - elif isinstance(x, np.ndarray): - input_list.append(x) - else: - input_list.append(x.numpy()) - - print(f"Training started for {num_iters} iterations:") - for i in tqdm(range(num_iters)): - outputs = self.shark_runner.forward(input_list, self.frontend) - return outputs - - def train(self, num_iters=1): - if self.frontend in ["torch", "pytorch"]: - return self._train_torch(num_iters) - elif self.frontend in ["tf", "tensorflow", "mhlo"]: - return self._train_tf(num_iters) - else: - print_err("Unknown frontend") - return diff --git a/shark/stress_test.py b/shark/stress_test.py deleted file mode 100644 index 44dc9c429c..0000000000 --- a/shark/stress_test.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright 2022 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from iree.runtime import query_available_drivers, get_driver -from shark.shark_downloader import download_model -from shark.shark_inference import SharkInference -from typing import List, Optional, Tuple -import numpy as np -import argparse -from shark.iree_utils._common import _IREE_DEVICE_MAP -import multiprocessing -from shark.shark_runner import supported_dialects -import logging -from concurrent.futures import ProcessPoolExecutor -from concurrent.futures.thread import ThreadPoolExecutor -import time -import numpy as np - -IREE_TO_SHARK_DRIVER_MAP = {v: k for k, v in _IREE_DEVICE_MAP.items()} - - -def stress_test_compiled_model( - shark_module_path: str, - function_name: str, - device: str, - inputs: List[np.ndarray], - golden_out: List[np.ndarray], - batch_size: int, - max_iterations: int, - max_duration_seconds: float, - inference_timeout_seconds: float, - tolerance_nulp: int, - stress_test_index: int, -): - logging.info( - f"Running stress test {stress_test_index} on device {device}." - ) - # All interactions with the module must run in a single thread. - # We are using execution in a sperate thread in order to be able - # to wait with a timeout on the inference operation. - module_executor = ThreadPoolExecutor(1) - shark_module = module_executor.submit( - SharkInference, - mlir_module=bytes(), - function_name=function_name, - device=device, - ).result() - module_executor.submit( - shark_module.load_module, shark_module_path - ).result() - input_batches = [np.repeat(arr, batch_size, axis=0) for arr in inputs] - golden_output_batches = np.repeat(golden_out, batch_size, axis=0) - report_interval_seconds = 10 - start_time = time.time() - previous_report_time = start_time - first_iteration_output = None - for i in range(max_iterations): - output = module_executor.submit( - shark_module.forward, input_batches - ).result(inference_timeout_seconds) - if first_iteration_output is None: - np.testing.assert_array_almost_equal_nulp( - golden_output_batches, output, nulp=tolerance_nulp - ) - first_iteration_output = output - else: - np.testing.assert_array_equal(output, first_iteration_output) - current_time = time.time() - if report_interval_seconds < current_time - previous_report_time: - logging.info( - f"Stress test {stress_test_index} on device " - f"{device} at iteration {i+1}" - ) - previous_report_time = current_time - if max_duration_seconds < current_time - start_time: - return - logging.info(f"Stress test {stress_test_index} on device {device} done.") - - -def get_device_type(device_name: str): - return device_name.split("://", 1)[0] - - -def get_device_types(device_names: str): - return [get_device_type(device_name) for device_name in device_names] - - -def query_devices(device_types: Optional[List[str]] = None) -> List[str]: - devices = [] - if device_types is None: - device_types = [ - IREE_TO_SHARK_DRIVER_MAP[name] - for name in query_available_drivers() - if name in IREE_TO_SHARK_DRIVER_MAP - ] - for device_type in device_types: - driver = get_driver(_IREE_DEVICE_MAP[device_type]) - device_infos = driver.query_available_devices() - for device_info in device_infos: - uri_path = ( - device_info["path"] - if device_info["path"] != "" - else str(device_info["device_id"]) - ) - device_uri = f"{device_type}://{uri_path}" - devices.append(device_uri) - return devices - - -def compile_stress_test_module( - device_types: List[str], mlir_model: str, func_name: str, mlir_dialect: str -) -> List[str]: - shark_module_paths = [] - for device_type in device_types: - logging.info( - 
f"Compiling stress test model for device type {device_type}." - ) - shark_module = SharkInference( - mlir_model, - func_name, - mlir_dialect=mlir_dialect, - device=device_type, - ) - shark_module_paths.append(shark_module.save_module()) - return shark_module_paths - - -def stress_test( - model_name: str, - dynamic_model: bool = False, - device_types: Optional[List[str]] = None, - device_names: Optional[List[str]] = None, - batch_size: int = 1, - max_iterations: int = 10**7, - max_duration_seconds: float = 3600, - inference_timeout_seconds: float = 60, - mlir_dialect: str = "linalg", - frontend: str = "torch", - oversubscription_factor: int = 1, - tolerance_nulp: int = 50000, -): - logging.info(f"Downloading stress test model {model_name}.") - mlir_model, func_name, inputs, golden_out = download_model( - model_name=model_name, dynamic=dynamic_model, frontend=frontend - ) - - if device_names is None or device_types is not None: - device_names = [] if device_names is None else device_names - with ProcessPoolExecutor() as executor: - # query_devices needs to run in a separate process, - # because it will interfere with other processes that are forked later. - device_names.extend( - executor.submit(query_devices, device_types).result() - ) - - device_types_set = list(set(get_device_types(device_names))) - with ProcessPoolExecutor() as executor: - # This needs to run in a subprocess because when compiling for CUDA, - # some stuff get intialized and cuInit will fail in a forked process - # later. It should be just compiling, but alas. - shark_module_paths_set = executor.submit( - compile_stress_test_module, - device_types_set, - mlir_model, - func_name, - mlir_dialect, - ).result() - device_type_shark_module_path_map = { - device_type: module_path - for device_type, module_path in zip( - device_types_set, shark_module_paths_set - ) - } - device_name_shark_module_path_map = { - device_name: device_type_shark_module_path_map[ - get_device_type(device_name) - ] - for device_name in device_names - } - - # This needs to run in a spearate process, because it uses the drvier chache - # in IREE and a subsequent call to `iree.runtime.SystemContext.add_vm_module` - # in a forked process will hang. - with multiprocessing.Pool( - len(device_name_shark_module_path_map) * oversubscription_factor - ) as process_pool: - process_pool.starmap( - stress_test_compiled_model, - [ - ( - module_path, - func_name, - device_name, - inputs, - golden_out, - batch_size, - max_iterations, - max_duration_seconds, - inference_timeout_seconds, - tolerance_nulp, - stress_test_index, - ) - for stress_test_index, (device_name, module_path) in enumerate( - list(device_name_shark_module_path_map.items()) - * oversubscription_factor - ) - ], - ) - - -if __name__ == "__main__": - logging.basicConfig(encoding="utf-8", level=logging.INFO) - parser = argparse.ArgumentParser( - description="Downloads, compiles and runs a model from the tank to stress test the system." 
-    )
-    parser.add_argument(
-        "--model", type=str, help="Model name in the tank.", default="alexnet"
-    )
-    parser.add_argument(
-        "--dynamic",
-        help="Use dynamic version of the model.",
-        action="store_true",
-        default=False,
-    )
-    parser.add_argument(
-        "--frontend", type=str, help="Frontend of the model.", default="torch"
-    )
-    parser.add_argument(
-        "--mlir-dialect",
-        type=str,
-        help="MLIR dialect of the model.",
-        default="linalg",
-        choices=supported_dialects,
-    )
-    parser.add_argument(
-        "--device-types",
-        type=str,
-        nargs="*",
-        choices=_IREE_DEVICE_MAP.keys(),
-        help="Runs the stress test on all devices of that type. "
-        "If absent and no devices are specified, the stress test "
-        "will run against all available devices.",
-    )
-    parser.add_argument(
-        "--devices",
-        type=str,
-        nargs="*",
-        help="List of devices to run the stress test on. "
-        "If device-types is also specified, the test runs against the union of the two.",
-    )
-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        help="Number of inputs to feed into the model.",
-        default=1,
-    )
-    parser.add_argument(
-        "--oversubscription",
-        type=int,
-        help="Oversubscription factor. Each device will execute the model "
-        "simultaneously this many times.",
-        default=1,
-    )
-    parser.add_argument(
-        "--max-iterations",
-        type=int,
-        help="Maximum number of iterations to run the stress test per device.",
-        default=10**7,
-    )
-    parser.add_argument(
-        "--max-duration",
-        type=float,
-        help="Maximum number of seconds to run the stress test.",
-        default=3600,
-    )
-    parser.add_argument(
-        "--inference-timeout",
-        type=float,
-        help="Timeout in seconds for a single model inference operation.",
-        default=60,
-    )
-    parser.add_argument(
-        "--tolerance-nulp",
-        type=int,
-        help="The maximum number of units in the last place (NULP) to tolerate "
-        "when verifying results against the golden reference output.",
-        default=50000,
-    )
-
-    args = parser.parse_known_args()[0]
-    stress_test(
-        model_name=args.model,
-        dynamic_model=args.dynamic,
-        frontend=args.frontend,
-        mlir_dialect=args.mlir_dialect,
-        device_types=args.device_types,
-        device_names=args.devices,
-        batch_size=args.batch_size,
-        oversubscription_factor=args.oversubscription,
-        max_iterations=args.max_iterations,
-        max_duration_seconds=args.max_duration,
-        inference_timeout_seconds=args.inference_timeout,
-        tolerance_nulp=args.tolerance_nulp,
-    )
diff --git a/shark/tests/test_shark_importer.py b/shark/tests/test_shark_importer.py
deleted file mode 100644
index 801a7b453a..0000000000
--- a/shark/tests/test_shark_importer.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# RUN: %PYTHON %s
-import numpy as np
-from shark.shark_importer import SharkImporter
-import pytest
-from shark.parser import shark_args
-from shark.shark_inference import SharkInference
-from shark.tflite_utils import TFLitePreprocessor
-import sys
-
-# model_path = "https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1?lite-format=tflite"
-
-
-# Inputs modified to be useful albert inputs.
-def generate_inputs(input_details): - for input in input_details: - print(str(input["shape"]), input["dtype"].__name__) - - args = [] - args.append( - np.random.randint( - low=0, - high=256, - size=input_details[0]["shape"], - dtype=input_details[0]["dtype"], - ) - ) - args.append( - np.ones( - shape=input_details[1]["shape"], dtype=input_details[1]["dtype"] - ) - ) - args.append( - np.zeros( - shape=input_details[2]["shape"], dtype=input_details[2]["dtype"] - ) - ) - return args - - -def compare_results(mlir_results, tflite_results, details): - print("Compare mlir_results VS tflite_results: ") - assert len(mlir_results) == len( - tflite_results - ), "Number of results do not match" - for i in range(len(details)): - mlir_result = mlir_results[i] - tflite_result = tflite_results[i] - mlir_result = mlir_result.astype(np.single) - tflite_result = tflite_result.astype(np.single) - assert mlir_result.shape == tflite_result.shape, "shape doesnot match" - max_error = np.max(np.abs(mlir_result - tflite_result)) - print("Max error (%d): %f", i, max_error) - - -class AlbertTfliteModuleTester: - def __init__( - self, - dynamic=False, - device="cpu", - save_mlir=False, - save_vmfb=False, - ): - self.dynamic = dynamic - self.device = device - self.save_mlir = save_mlir - self.save_vmfb = save_vmfb - - def create_and_check_module(self): - shark_args.save_mlir = self.save_mlir - shark_args.save_vmfb = self.save_vmfb - tflite_preprocessor = TFLitePreprocessor(model_name="albert_lite_base") - - raw_model_file_path = tflite_preprocessor.get_raw_model_file() - inputs = tflite_preprocessor.get_inputs() - tflite_interpreter = tflite_preprocessor.get_interpreter() - - my_shark_importer = SharkImporter( - module=tflite_interpreter, - inputs=inputs, - frontend="tflite", - raw_model_file=raw_model_file_path, - ) - mlir_model, func_name = my_shark_importer.import_mlir() - - shark_module = SharkInference( - mlir_module=mlir_model, - function_name=func_name, - device=self.device, - mlir_dialect="tflite", - ) - - # Case1: Use shark_importer default generate inputs - shark_module.compile() - mlir_results = shark_module.forward(inputs) - ## post process results for compare - input_details, output_details = tflite_preprocessor.get_model_details() - mlir_results = list(mlir_results) - for i in range(len(output_details)): - dtype = output_details[i]["dtype"] - mlir_results[i] = mlir_results[i].astype(dtype) - tflite_results = tflite_preprocessor.get_golden_output() - compare_results(mlir_results, tflite_results, output_details) - - # Case2: Use manually set inputs - input_details, output_details = tflite_preprocessor.get_model_details() - inputs = generate_inputs(input_details) # new inputs - - shark_module = SharkInference( - mlir_module=mlir_model, - function_name=func_name, - device=self.device, - mlir_dialect="tflite", - ) - shark_module.compile() - mlir_results = shark_module.forward(inputs) - ## post process results for compare - tflite_results = tflite_preprocessor.get_golden_output() - compare_results(mlir_results, tflite_results, output_details) - # print(mlir_results) - - -# A specific case can be run by commenting different cases. Runs all the test -# across cpu, gpu and vulkan according to available drivers. -pytest_param = pytest.mark.parametrize( - ("dynamic", "device"), - [ - pytest.param(False, "cpu"), - # TODO: Language models are failing for dynamic case.. 
- pytest.param(True, "cpu", marks=pytest.mark.skip), - ], -) - - -@pytest_param -@pytest.mark.xfail( - sys.platform == "darwin", reason="known macos tflite install issue" -) -def test_albert(dynamic, device): - module_tester = AlbertTfliteModuleTester(dynamic=dynamic, device=device) - module_tester.create_and_check_module() - - -if __name__ == "__main__": - test_albert(False, "cpu") diff --git a/shark/tests/test_stress_test.py b/shark/tests/test_stress_test.py deleted file mode 100644 index 1474da124a..0000000000 --- a/shark/tests/test_stress_test.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2022 The Nod Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import subprocess -import sys -import importlib.util - - -def test_stress_test(): - subprocess.check_call( - [ - sys.executable, - importlib.util.find_spec("shark.stress_test").origin, - "--model=squeezenet1_0", - "--devices", - "cpu", - "--max-iterations=1", - ] - ) diff --git a/shark/tests/test_txt2img_ui.py b/shark/tests/test_txt2img_ui.py deleted file mode 100644 index 278ef2cd3b..0000000000 --- a/shark/tests/test_txt2img_ui.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest -from unittest.mock import mock_open, patch - -from apps.stable_diffusion.web.ui.txt2img_ui import ( - export_settings, - load_settings, - all_gradio_labels, -) - - -class TestExportSettings(unittest.TestCase): - @patch("builtins.open", new_callable=mock_open) - @patch("json.dump") - def test_export_settings(self, mock_json_dump, mock_file): - test_values = ["value1", "value2", "value3"] - expected_output = { - "txt2img": { - label: value - for label, value in zip(all_gradio_labels, test_values) - } - } - - export_settings(*test_values) - mock_file.assert_called_once_with("./ui/settings.json", "w") - mock_json_dump.assert_called_once_with( - expected_output, mock_file(), indent=4 - ) - - @patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load") - @patch( - "builtins.open", - new_callable=mock_open, - read_data='{"txt2img": {"some_setting": "some_value"}}', - ) - def test_load_settings_file_exists(self, mock_file, mock_json_load): - mock_json_load.return_value = { - "txt2img": { - "txt2img_custom_model": "custom_model_value", - "custom_vae": "custom_vae_value", - } - } - - settings = load_settings() - self.assertEqual(settings[0], "custom_model_value") - self.assertEqual(settings[1], "custom_vae_value") - - @patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load") - @patch("builtins.open", side_effect=FileNotFoundError) - def test_load_settings_file_not_found(self, mock_file, mock_json_load): - settings = load_settings() - - default_lora_weights = "None" - self.assertEqual(settings[4], default_lora_weights) - - @patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load") - @patch("builtins.open", new_callable=mock_open, read_data="{}") - def test_load_settings_key_error(self, mock_file, mock_json_load): - mock_json_load.return_value = {} - - settings = load_settings() - default_lora_weights = "None" - 
self.assertEqual(settings[4], default_lora_weights) diff --git a/shark/tflite_utils.py b/shark/tflite_utils.py deleted file mode 100644 index b6755aa55d..0000000000 --- a/shark/tflite_utils.py +++ /dev/null @@ -1,208 +0,0 @@ -import tensorflow as tf -import numpy as np -import os -import csv -import urllib.request - - -class TFLiteModelUtil: - def __init__(self, raw_model_file): - self.raw_model_file = str(raw_model_file) - self.tflite_interpreter = None - self.input_details = None - self.output_details = None - self.inputs = [] - - def setup_tflite_interpreter(self): - self.tflite_interpreter = tf.lite.Interpreter( - model_path=self.raw_model_file - ) - self.tflite_interpreter.allocate_tensors() - # default input initialization - return self.get_model_details() - - def get_model_details(self): - print("Get tflite input output details") - self.input_details = self.tflite_interpreter.get_input_details() - self.output_details = self.tflite_interpreter.get_output_details() - return self.input_details, self.output_details - - def invoke_tflite(self, inputs): - self.inputs = inputs - print("invoke_tflite") - for i, input in enumerate(self.inputs): - self.tflite_interpreter.set_tensor( - self.input_details[i]["index"], input - ) - self.tflite_interpreter.invoke() - - # post process tflite_result for compare with mlir_result, - # for tflite the output is a list of numpy.tensor - tflite_results = [] - for output_detail in self.output_details: - tflite_results.append( - np.array( - self.tflite_interpreter.get_tensor(output_detail["index"]) - ) - ) - - for i in range(len(self.output_details)): - # print("output_details ", i, "shape", self.output_details[i]["shape"].__name__, - # ", dtype: ", self.output_details[i]["dtype"].__name__) - out_dtype = self.output_details[i]["dtype"] - tflite_results[i] = tflite_results[i].astype(out_dtype) - return tflite_results - - -class TFLitePreprocessor: - def __init__( - self, - model_name, - input_details=None, - output_details=None, - model_path=None, - ): - self.model_name = model_name - self.input_details = ( - input_details # used for tflite, optional for tf/pytorch - ) - self.output_details = ( - output_details # used for tflite, optional for tf/pytorch - ) - self.inputs = [] - self.model_path = model_path # url to download the model - self.raw_model_file = ( - None # local address for raw tf/tflite/pytorch model - ) - self.mlir_file = ( - None # local address for .mlir file of tf/tflite/pytorch model - ) - self.mlir_model = None # read of .mlir file - self.output_tensor = ( - None # the raw tf/pytorch/tflite_output_tensor, not mlir_tensor - ) - self.interpreter = ( - None # could be tflite/tf/torch_interpreter in utils - ) - self.input_file = None - self.output_file = None - - # create tmp model file directory - if self.model_path is None and self.model_name is None: - print( - "Error. No model_path, No model name,Please input either one." 
- ) - return - - print("Setting up for TMP_WORK_DIR") - self.workdir = os.path.join( - os.path.dirname(__file__), "./../gen_shark_tank" - ) - os.makedirs(self.workdir, exist_ok=True) - print(f"TMP_WORK_DIR = {self.workdir}") - - # compile and run tfhub tflite - load_model_success = self.load_tflite_model() - if not load_model_success: - print("Error, load tflite model fail") - return - - if (self.input_details is None) or (self.output_details is None): - # print("Setting up tflite interpreter to get model input details") - self.setup_interpreter() - - inputs = self.generate_inputs(self.input_details) # device_inputs - self.setup_inputs(inputs) - - def load_tflite_model(self): - # use model name get dir. - tflite_model_name_dir = os.path.join( - self.workdir, str(self.model_name) - ) - - os.makedirs(tflite_model_name_dir, exist_ok=True) - print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}") - - self.raw_model_file = "/".join( - [tflite_model_name_dir, str(self.model_name) + "_tflite.tflite"] - ) - self.mlir_file = "/".join( - [tflite_model_name_dir, str(self.model_name) + "_tflite.mlir"] - ) - self.input_file = "/".join([tflite_model_name_dir, "inputs"]) - self.output_file = "/".join([tflite_model_name_dir, "golden_out"]) - # np.save("/".join([tflite_model_name_dir, "function_name"]), np.array("main")) - - if os.path.exists(self.raw_model_file): - print( - "Local address for .tflite model file Exists: ", - self.raw_model_file, - ) - else: - print("No local tflite file, Download tflite model") - if self.model_path is None: - # get model file from tflite_model_list.csv or download from gs://bucket - print("No model_path, get from tflite_model_list.csv") - tflite_model_list_path = os.path.join( - os.path.dirname(__file__), - "../tank/tflite/tflite_model_list.csv", - ) - tflite_model_list = csv.reader(open(tflite_model_list_path)) - for row in tflite_model_list: - if str(row[0]) == str(self.model_name): - self.model_path = row[1] - print("tflite_model_name", str(row[0])) - print("tflite_model_link", self.model_path) - if self.model_path is None: - print("Error, No model path find in tflite_model_list.csv") - return False - urllib.request.urlretrieve(self.model_path, self.raw_model_file) - return True - - def setup_interpreter(self): - self.interpreter = TFLiteModelUtil(self.raw_model_file) - ( - self.input_details, - self.output_details, - ) = self.interpreter.setup_tflite_interpreter() - - def generate_inputs(self, input_details): - self.inputs = [] - for tmp_input in input_details: - print( - "input_details shape:", - str(tmp_input["shape"]), - " type:", - tmp_input["dtype"].__name__, - ) - self.inputs.append( - np.ones(shape=tmp_input["shape"], dtype=tmp_input["dtype"]) - ) - return self.inputs - - def setup_inputs(self, inputs): - # print("Setting up inputs") - self.inputs = inputs - - def get_mlir_model(self): - return self.mlir_model - - def get_mlir_file(self): - return self.mlir_file - - def get_inputs(self): - return self.inputs - - def get_golden_output(self): - self.output_tensor = self.interpreter.invoke_tflite(self.inputs) - np.savez(self.output_file, *self.output_tensor) - return self.output_tensor - - def get_model_details(self): - return self.input_details, self.output_details - - def get_raw_model_file(self): - return self.raw_model_file - - def get_interpreter(self): - return self.interpreter diff --git a/shark/torch_mlir_lockstep_tensor.py b/shark/torch_mlir_lockstep_tensor.py deleted file mode 100644 index bfcaafed70..0000000000 --- 
a/shark/torch_mlir_lockstep_tensor.py +++ /dev/null @@ -1,220 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. -import contextlib -import re -import traceback -import warnings -from typing import Any -import numpy as np - -import torch -from torch.utils._pytree import tree_map - -from torch_mlir.eager_mode.ir_building import build_mlir_module -from torch_mlir.eager_mode.torch_mlir_dispatch import ( - UnsupportedByTorchMlirEagerMode, - normalize_args_kwargs, - check_get_aliased_arg, -) -from torch_mlir.eager_mode import EAGER_MODE_DEBUG -from torch_mlir.eager_mode.torch_mlir_tensor import ( - TorchMLIRTensor, - check_requires_grad, - make_wrapper_subclass_from_torch_tensor, - make_bare_wrapper_subclass, - UNSUPPORTED_OPS, - no_dispatch, -) -from torch_mlir.eager_mode import torch_mlir_tensor -from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend - - -backend = EagerModeIREELinalgOnTensorsBackend("cpu") -torch_mlir_tensor.backend = backend -rtol = 1e-04 -atol = 1e-05 - - -class TorchMLIRLockstepTensor(TorchMLIRTensor): - """This class overrides the dispatching for TorchMLIRTensor to allow for an op-by-op numerical comparison between PyTorch and the Torch-MLIR -> IREE backend compilation pipeline. This only supports the IREE backend and focuses on op-by-op level verification. - - TODO: Extend this to do a cumulative trace with summary statistics at the end. Possibly requires a wrapper environment to store full trace info. - """ - - def __new__(cls, elem, **kwargs): - if kwargs.get("constructing_from_device_tensor", False): - tensor_meta_data = backend.get_torch_metadata(elem, kwargs) - r = make_bare_wrapper_subclass( - cls=cls, - size=tensor_meta_data.size, - strides=tensor_meta_data.strides, - storage_offset=tensor_meta_data.storage_offset, - dtype=tensor_meta_data.dtype, - layout=tensor_meta_data.layout, - device=tensor_meta_data.device, - requires_grad=tensor_meta_data.requires_grad, - ) - r.elem = elem - elif isinstance(elem, torch.nn.Parameter): - r = make_wrapper_subclass_from_torch_tensor( - cls, elem.data, **kwargs - ) - # This is a hack to handle non-contiguous data through IREE-backend - nt = elem.detach().data.numpy() - if not nt.flags["C_CONTIGUOUS"]: - nt = np.ascontiguousarray(nt, dtype=nt.dtype) - r.elem = backend.transfer_from_torch_to_device( - torch.from_numpy(nt) - ) - elif isinstance(elem, torch.Tensor): - r = make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs) - # Ditto TODO: Find a better way to handle this - nt = elem.numpy() - if not nt.flags["C_CONTIGUOUS"]: - nt = np.ascontiguousarray(nt, dtype=nt.dtype) - r.elem = backend.transfer_from_torch_to_device( - torch.from_numpy(nt) - ) - # This branch handles the case when a python scalar is passed to some op - # or is returned from some aten op, such as _local_scalar_dense. - elif isinstance(elem, (int, float, bool)): - return elem - else: - raise ValueError(f"Unknown element type: {type(elem)}") - return r - - def __repr__(self): - if self.grad_fn: - return f"TorchMLIRLockstepTensor({self.elem}, backend={backend.__class__.__name__}, grad_fn={self.grad_fn})" - else: - return f"TorchMLIRLockstepTensor({self.elem}, backend={backend.__class__.__name__})" - - """This does essentially the same dispatch as TorchMLIRTensor but operates as if debug mode is enabled. 
The numeric verification happens after the Torch-MLIR result is obtained by comparing against the - """ - - @classmethod - def __torch_dispatch__(cls, func, _types, args=(), kwargs=None): - requires_grad = check_requires_grad(*args, **kwargs) - try: - with no_dispatch(): - if hasattr(func, "op_name"): - op_name = func.op_name - elif hasattr(func, "__name__"): - # Handle builtin_function_or_method. - op_name = func.__name__ - else: - raise RuntimeError(f"op {func} has no name") - - if UNSUPPORTED_OPS.match(op_name): - raise UnsupportedByTorchMlirEagerMode(op_name) - - if not hasattr(func, "_schema"): - raise RuntimeError(f"op {func} has no schema.") - - normalized_kwargs = normalize_args_kwargs(func, args, kwargs) - - if "layout" in normalized_kwargs and normalized_kwargs[ - "layout" - ] not in {0, None}: - raise UnsupportedByTorchMlirEagerMode( - f"{normalized_kwargs['layout']} layout not supported." - ) - if "memory_format" in normalized_kwargs and normalized_kwargs[ - "memory_format" - ] not in {0, None}: - raise UnsupportedByTorchMlirEagerMode( - f"{normalized_kwargs['memory_format']} memory format not supported." - ) - eager_module = build_mlir_module(func, normalized_kwargs) - device_tensor_args = [ - kwarg.elem - for _, kwarg in normalized_kwargs.items() - if isinstance(kwarg, cls) - ] - assert len(eager_module.body.operations[0].arguments) == len( - device_tensor_args - ), "Number of parameters and number of arguments differs." - op_mlir_backend_callable = backend.compile(eager_module) - out = op_mlir_backend_callable(*device_tensor_args) - out = tree_map( - lambda x: cls( - x, - requires_grad=requires_grad, - constructing_from_device_tensor=True, - ), - out, - ) - - # Numeric verification; Value for comparison comes from PyTorch eager - with no_dispatch(): - unwrapped_args = tree_map(cls.unwrap, args) - unwrapped_kwargs = tree_map(cls.unwrap, kwargs) - if "_reshape_alias" in op_name: - native_out = torch.ops.aten.view( - unwrapped_args[0], unwrapped_args[1] - ) - else: - native_out = func(*unwrapped_args, **unwrapped_kwargs) - - native_out = tree_map( - lambda x: cls(x, requires_grad=requires_grad), native_out - ).elem - tmp_out = out.elem - - try: - np.testing.assert_allclose( - native_out.to_host(), - tmp_out.to_host(), - rtol=rtol, - atol=atol, - ) - except Exception as e: - shaped_args = [ - arg.shape if torch.is_tensor(arg) else arg - for arg in unwrapped_args - ] - shaped_kwargs = [ - kwarg.shape if torch.is_tensor(kwarg) else kwarg - for kwarg in unwrapped_kwargs - ] - warnings.warn( - f"Lockstep accuracy verification failed with error: *{str(e)}*; " - f"Dispatched function name: *{str(func)}*; " - f"Dispatched function args: *{str(shaped_args)}*; " - f"Dispatched function kwargs: *{str(shaped_kwargs)}*; " - ) - except Exception as e: - warnings.warn(traceback.format_exc()) - if isinstance(e, UnsupportedByTorchMlirEagerMode): - warnings.warn( - f"Couldn't use TorchMLIR eager because current incompatibility: *{str(e)}*; running through PyTorch eager." 
-                )
-            else:
-                warnings.warn(
-                    f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
-                    f"Running through PyTorch eager"
-                )
-
-            with no_dispatch():
-                unwrapped_args = tree_map(cls.unwrap, args)
-                unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
-                if "_reshape_alias" in op_name:
-                    out = torch.ops.aten.view(
-                        unwrapped_args[0], unwrapped_args[1]
-                    )
-                else:
-                    out = func(*unwrapped_args, **unwrapped_kwargs)
-
-            out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out)
-
-        maybe_aliased_arg_name = check_get_aliased_arg(func)
-        if maybe_aliased_arg_name is not None:
-            warnings.warn(
-                f"Found aliased arg, but didn't copy tensor contents. This could lead to incorrect results for E2E model execution but doesn't affect the validity of the lockstep op verification."
-            )
-            # TODO: Find a way to handle argument aliasing for IREE backend
-            # backend.copy_into(normalized_kwargs[maybe_aliased_arg_name].elem, out.elem)
-
-        return out
diff --git a/shark/torch_mlir_utils.py b/shark/torch_mlir_utils.py
deleted file mode 100644
index 85593d5402..0000000000
--- a/shark/torch_mlir_utils.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright 2020 The Nod Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from torch_mlir.ir import StringAttr
-import torch_mlir
-from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
-import tempfile
-from shark.parser import shark_args
-import io
-
-mlir_type_mapping_dict = {
-    "linalg": torch_mlir.OutputType.LINALG_ON_TENSORS,
-    "stablehlo": torch_mlir.OutputType.STABLEHLO,
-    "tosa": torch_mlir.OutputType.TOSA,
-}
-
-
-def get_module_name_for_asm_dump(module):
-    """Gets a name suitable for an assembly dump.
-    The name is not guaranteed to be unique.
-    """
-    if "torch.debug_module_name" not in module.operation.attributes:
-        return "UnnamedModule"
-    return StringAttr(
-        module.operation.attributes["torch.debug_module_name"]
-    ).value
-
-
-def run_on_refbackend(torch_module, inputs):
-    backend = refbackend.RefBackendLinalgOnTensorsBackend()
-    compiled = backend.compile(torch_module)
-    jit_module = backend.load(compiled)
-    np_inputs = [x.numpy() for x in inputs]
-    return jit_module.forward(np_inputs[0])
-
-
-# Creates dynamic dims for all dims.
-# TODO: Pass user specified dynamic dims.
-def create_dynamic_placeholders(inputs):
-    placeholders = []
-    for inp in inputs:
-        placeholder = torch_mlir.TensorPlaceholder.like(
-            inp, dynamic_axes=[i for i in range(len(inp.shape))]
-        )
-        placeholders.append(placeholder)
-    return tuple(placeholders)
-
-
-def get_torch_mlir_module(
-    module,
-    input: tuple,
-    dynamic: bool,
-    jit_trace: bool,
-    return_str: bool = False,
-    mlir_type: str = "linalg",
-):
-    """Get the MLIR module (linalg-on-tensors by default) from the TorchScript module."""
-    ignore_traced_shapes = False
-    if dynamic:
-        input = create_dynamic_placeholders(input)
-        if jit_trace:
-            ignore_traced_shapes = True
-
-    tempfile.tempdir = "."
-
-    mlir_module = torch_mlir.compile(
-        module,
-        input,
-        output_type=mlir_type_mapping_dict[mlir_type],
-        use_tracing=jit_trace,
-        ignore_traced_shapes=ignore_traced_shapes,
-    )
-
-    if return_str:
-        return mlir_module.operation.get_asm()
-    bytecode_stream = io.BytesIO()
-    mlir_module.operation.write_bytecode(bytecode_stream)
-    bytecode = bytecode_stream.getvalue()
-    return bytecode
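
For reviewers who want to sanity-check what callers lose with this purge, below is a minimal sketch of how the deleted get_torch_mlir_module / SharkInference pair was typically driven end to end. It is reconstructed only from the deleted sources above; the toy module, input shape, and device="cpu" choice are illustrative assumptions, not part of this patch or the replacement ireert flow.

# Illustrative only: exercises the *removed* shark/ API surface.
# The toy model and shapes are hypothetical; the names and signatures
# come from the deleted files above.
import torch

from shark.shark_inference import SharkInference          # deleted by this patch
from shark.torch_mlir_utils import get_torch_mlir_module  # deleted by this patch


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


example_input = (torch.randn(1, 16),)

# Lower the module to linalg-on-tensors MLIR bytecode via torch-mlir.
mlir_bytecode = get_torch_mlir_module(
    ToyModel(), example_input, dynamic=False, jit_trace=False, mlir_type="linalg"
)

# Compile with IREE and run through the (now removed) SharkInference wrapper.
shark_module = SharkInference(mlir_bytecode, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward(tuple(x.numpy() for x in example_input))
print(result)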