Skip to content

Commit

Permalink
Purge shark/ directory, minimal ireert api usage for dynamically load…
Browse files Browse the repository at this point in the history
…ed plugins
  • Loading branch information
eagarvey-amd committed Jun 4, 2024
1 parent dac7a29 commit 4aa2d8b
Show file tree
Hide file tree
Showing 84 changed files with 75 additions and 14,684 deletions.
198 changes: 72 additions & 126 deletions apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,60 @@
from cpuinfo import get_cpu_info


def iree_device_map(device):
uri_parts = device.split("://", 2)
iree_driver = (
_IREE_DEVICE_MAP[uri_parts[0]]
if uri_parts[0] in _IREE_DEVICE_MAP
else uri_parts[0]
)
if len(uri_parts) == 1:
return iree_driver
elif "rocm" in uri_parts:
return "rocm"
else:
return f"{iree_driver}://{uri_parts[1]}"


def get_supported_device_list():
return list(_IREE_DEVICE_MAP.keys())


_IREE_DEVICE_MAP = {
"cpu": "local-task",
"cpu-task": "local-task",
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "metal",
"rocm": "rocm",
"hip": "hip",
"intel-gpu": "level_zero",
}


def iree_target_map(device):
if "://" in device:
device = device.split("://")[0]
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device


_IREE_TARGET_MAP = {
"cpu": "llvm-cpu",
"cpu-task": "llvm-cpu",
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan-spirv",
"metal": "metal",
"rocm": "rocm",
"hip": "rocm",
"intel-gpu": "opencl-spirv",
}



def get_available_devices():
return ["AMD Radeon 780M => rocm"]
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map

device_list = []
try:
Expand Down Expand Up @@ -91,13 +139,29 @@ def get_devices_by_name(driver_name):
break
return available_devices

def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by Studio pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others

def parse_device(device_str, target_override=""):
from shark.iree_utils.compile_utils import (
clean_device_info,
get_iree_target_triple,
iree_target_map,
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
if len(device_id) <= 2:
device_id = int(device_id)

if device not in ["hip", "rocm", "vulkan"]:
device_id = None
if device in ["hip", "rocm", "vulkan"] and device_id == None:
device_id = 0
return device, device_id

def parse_device(device_str, target_override=""):

rt_driver, device_id = clean_device_info(device_str)
target_backend = iree_target_map(rt_driver)
Expand Down Expand Up @@ -144,9 +208,6 @@ def get_rocm_target_chip(device_str):
if key in device_str:
return rocm_chip_map[key]
return None
# raise AssertionError(
# f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues."
# )


def get_all_devices(driver_name):
Expand Down Expand Up @@ -179,7 +240,6 @@ def get_device_mapping(driver, key_combination=3):
dict: map to possible device names user can input mapped to desired
combination of name/path.
"""
from shark.iree_utils._common import iree_device_map

driver = iree_device_map(driver)
device_list = get_all_devices(driver)
Expand Down Expand Up @@ -226,118 +286,4 @@ def get_opt_flags(model, precision="fp16"):
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags


# def map_device_to_name_path(device, key_combination=3):
# """Gives the appropriate device data (supported name/path) for user
# selected execution device
# Args:
# device (str): user
# key_combination (int, optional): choice for mapping value for
# device name.
# 1 : path
# 2 : name
# 3 : (name, path)
# Defaults to 3.
# Raises:
# ValueError:
# Returns:
# str / tuple: returns the mapping str or tuple of mapping str for
# the device depending on key_combination value
# """
# driver = device.split("://")[0]
# device_map = get_device_mapping(driver, key_combination)
# try:
# device_mapping = device_map[device]
# except KeyError:
# raise ValueError(f"Device '{device}' is not a valid device.")
# return device_mapping

# def get_devices_by_name(driver_name):
# from shark.iree_utils._common import iree_device_map

# device_list = []
# try:
# driver_name = iree_device_map(driver_name)
# device_list_dict = get_all_devices(driver_name)
# print(f"{driver_name} devices are available.")
# except:
# print(f"{driver_name} devices are not available.")
# else:
# cpu_name = get_cpu_info()["brand_raw"]
# for i, device in enumerate(device_list_dict):
# device_name = (
# cpu_name if device["name"] == "default" else device["name"]
# )
# if "local" in driver_name:
# device_list.append(
# f"{device_name} => {driver_name.replace('local', 'cpu')}"
# )
# else:
# # for drivers with single devices
# # let the default device be selected without any indexing
# if len(device_list_dict) == 1:
# device_list.append(f"{device_name} => {driver_name}")
# else:
# device_list.append(f"{device_name} => {driver_name}://{i}")
# return device_list

# set_iree_runtime_flags()

# available_devices = []
# from shark.iree_utils.vulkan_utils import (
# get_all_vulkan_devices,
# )

# vulkaninfo_list = get_all_vulkan_devices()
# vulkan_devices = []
# id = 0
# for device in vulkaninfo_list:
# vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
# id += 1
# if id != 0:
# print(f"vulkan devices are available.")
# available_devices.extend(vulkan_devices)
# metal_devices = get_devices_by_name("metal")
# available_devices.extend(metal_devices)
# cuda_devices = get_devices_by_name("cuda")
# available_devices.extend(cuda_devices)
# rocm_devices = get_devices_by_name("rocm")
# available_devices.extend(rocm_devices)
# cpu_device = get_devices_by_name("cpu-sync")
# available_devices.extend(cpu_device)
# cpu_device = get_devices_by_name("cpu-task")
# available_devices.extend(cpu_device)
# return available_devices


# # Generate and return a new seed if the provided one is not in the
# # supported range (including -1)
# def sanitize_seed(seed: int | str):
# seed = int(seed)
# uint32_info = np.iinfo(np.uint32)
# uint32_min, uint32_max = uint32_info.min, uint32_info.max
# if seed < uint32_min or seed >= uint32_max:
# seed = randint(uint32_min, uint32_max)
# return seed


# # take a seed expression in an input format and convert it to
# # a list of integers, where possible
# def parse_seed_input(seed_input: str | list | int):
# if isinstance(seed_input, str):
# try:
# seed_input = json.loads(seed_input)
# except (ValueError, TypeError):
# seed_input = None

# if isinstance(seed_input, int):
# return [seed_input]

# if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
# return seed_input

# raise TypeError(
# "Seed input must be an integer or an array of integers in JSON format"
# )
return iree_flags
6 changes: 3 additions & 3 deletions apps/shark_studio/web/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def webui():
launch_api = cmd_opts.api
initialize.initialize()

from ui.chat import chat_element
#from ui.chat import chat_element
from ui.sd import sd_element
from ui.outputgallery import outputgallery_element

Expand Down Expand Up @@ -194,8 +194,8 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
sd_element.render()
with gr.TabItem(label="Output Gallery", id=1):
outputgallery_element.render()
with gr.TabItem(label="Chat Bot", id=2, visible=False):
chat_element.render()
# with gr.TabItem(label="Chat Bot", id=2):
# chat_element.render()

studio_web.queue()

Expand Down
28 changes: 0 additions & 28 deletions shark/__init__.py

This file was deleted.

78 changes: 0 additions & 78 deletions shark/backward_makefx.py

This file was deleted.

Empty file removed shark/dynamo_backend/__init__.py
Empty file.
Loading

0 comments on commit 4aa2d8b

Please sign in to comment.