Skip to content

Commit

Permalink
Purge unused code and patch out iree runtime handling from init
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Jun 3, 2024
1 parent 5960045 commit dac7a29
Show file tree
Hide file tree
Showing 27 changed files with 156 additions and 11,483 deletions.
28 changes: 15 additions & 13 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@

from pathlib import Path
from random import randint
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
SharkSDXLPipeline,
)



from apps.shark_studio.api.controlnet import control_adapter_map
Expand All @@ -31,11 +28,8 @@
save_output_img,
)

from apps.shark_studio.modules.ckpt_processing import (
preprocessCKPT,
save_irpa,
)

from subprocess import check_output
EMPTY_SD_MAP = {
"clip": None,
"scheduler": None,
Expand Down Expand Up @@ -67,7 +61,6 @@ def load_script(source, module_name):
:param module_name: name of module to register in sys.modules
:return: loaded module
"""

spec = importlib.util.spec_from_file_location(module_name, source)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
Expand Down Expand Up @@ -118,10 +111,15 @@ def __init__(
self.dynamic_steps = False
self.model_map = custom_module.MODEL_MAP
elif self.is_sdxl:
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
SharkSDXLPipeline,
)
self.turbine_pipe = SharkSDXLPipeline
self.dynamic_steps = False
self.model_map = EMPTY_SDXL_MAP
else:
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline

self.turbine_pipe = SharkSDPipeline
self.dynamic_steps = True
self.model_map = EMPTY_SD_MAP
Expand Down Expand Up @@ -207,6 +205,10 @@ def prepare_pipe(
self.compiled_pipeline = compiled_pipeline

if custom_weights:
from apps.shark_studio.modules.ckpt_processing import (
preprocessCKPT,
save_irpa,
)
custom_weights = os.path.join(
get_checkpoints_path("checkpoints"),
safe_name(self.base_model_id.split("/")[-1]),
Expand Down Expand Up @@ -534,11 +536,11 @@ def safe_name(name):
global_obj._init()

sd_json = view_json_file(
get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json"))
get_resource_path(os.path.join(cmd_opts.config_dir, cmd_opts.default_config))
)
sd_kwargs = json.loads(sd_json)
for arg in vars(cmd_opts):
if arg in sd_kwargs:
sd_kwargs[arg] = getattr(cmd_opts, arg)
# for arg in vars(cmd_opts):
# if arg in sd_kwargs:
# sd_kwargs[arg] = getattr(cmd_opts, arg)
for i in shark_sd_fn_dict_input(sd_kwargs):
print(i)
284 changes: 116 additions & 168 deletions apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,11 @@
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from cpuinfo import get_cpu_info

# TODO: migrate these utils to studio
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)



def get_available_devices():
return ["AMD Radeon 780M => rocm"]
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map

Expand Down Expand Up @@ -49,7 +45,7 @@ def get_devices_by_name(driver_name):
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list

set_iree_runtime_flags()
#set_iree_runtime_flags()

available_devices = []
rocm_devices = get_devices_by_name("rocm")
Expand Down Expand Up @@ -96,55 +92,6 @@ def get_devices_by_name(driver_name):
return available_devices


def set_init_device_flags():
if "vulkan" in cmd_opts.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()

# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
cmd_opts.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_vulkan_target_triple}."
)
elif "cuda" in cmd_opts.device:
cmd_opts.device = "cuda"
elif "metal" in cmd_opts.device:
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_metal_target_platform:
from shark.iree_utils.metal_utils import get_metal_target_triple

triple = get_metal_target_triple(device_name)
if triple is not None:
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_metal_target_platform}."
)
elif "cpu" in cmd_opts.device:
cmd_opts.device = "cpu"


def set_iree_runtime_flags():
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if cmd_opts.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if cmd_opts.device_allocator_heap_key:
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)


def parse_device(device_str, target_override=""):
from shark.iree_utils.compile_utils import (
clean_device_info,
Expand Down Expand Up @@ -213,6 +160,7 @@ def get_all_devices(driver_name):
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
del driver
return device_list_src


Expand Down Expand Up @@ -281,115 +229,115 @@ def get_opt_flags(model, precision="fp16"):
return iree_flags


def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user
selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for
the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping

def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map

device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
cpu_name = get_cpu_info()["brand_raw"]
for i, device in enumerate(device_list_dict):
device_name = (
cpu_name if device["name"] == "default" else device["name"]
)
if "local" in driver_name:
device_list.append(
f"{device_name} => {driver_name.replace('local', 'cpu')}"
)
else:
# for drivers with single devices
# let the default device be selected without any indexing
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list

set_iree_runtime_flags()

available_devices = []
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)

vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
return available_devices


# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed


# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None

if isinstance(seed_input, int):
return [seed_input]

if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
return seed_input

raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)
# def map_device_to_name_path(device, key_combination=3):
# """Gives the appropriate device data (supported name/path) for user
# selected execution device
# Args:
# device (str): user
# key_combination (int, optional): choice for mapping value for
# device name.
# 1 : path
# 2 : name
# 3 : (name, path)
# Defaults to 3.
# Raises:
# ValueError:
# Returns:
# str / tuple: returns the mapping str or tuple of mapping str for
# the device depending on key_combination value
# """
# driver = device.split("://")[0]
# device_map = get_device_mapping(driver, key_combination)
# try:
# device_mapping = device_map[device]
# except KeyError:
# raise ValueError(f"Device '{device}' is not a valid device.")
# return device_mapping

# def get_devices_by_name(driver_name):
# from shark.iree_utils._common import iree_device_map

# device_list = []
# try:
# driver_name = iree_device_map(driver_name)
# device_list_dict = get_all_devices(driver_name)
# print(f"{driver_name} devices are available.")
# except:
# print(f"{driver_name} devices are not available.")
# else:
# cpu_name = get_cpu_info()["brand_raw"]
# for i, device in enumerate(device_list_dict):
# device_name = (
# cpu_name if device["name"] == "default" else device["name"]
# )
# if "local" in driver_name:
# device_list.append(
# f"{device_name} => {driver_name.replace('local', 'cpu')}"
# )
# else:
# # for drivers with single devices
# # let the default device be selected without any indexing
# if len(device_list_dict) == 1:
# device_list.append(f"{device_name} => {driver_name}")
# else:
# device_list.append(f"{device_name} => {driver_name}://{i}")
# return device_list

# set_iree_runtime_flags()

# available_devices = []
# from shark.iree_utils.vulkan_utils import (
# get_all_vulkan_devices,
# )

# vulkaninfo_list = get_all_vulkan_devices()
# vulkan_devices = []
# id = 0
# for device in vulkaninfo_list:
# vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
# id += 1
# if id != 0:
# print(f"vulkan devices are available.")
# available_devices.extend(vulkan_devices)
# metal_devices = get_devices_by_name("metal")
# available_devices.extend(metal_devices)
# cuda_devices = get_devices_by_name("cuda")
# available_devices.extend(cuda_devices)
# rocm_devices = get_devices_by_name("rocm")
# available_devices.extend(rocm_devices)
# cpu_device = get_devices_by_name("cpu-sync")
# available_devices.extend(cpu_device)
# cpu_device = get_devices_by_name("cpu-task")
# available_devices.extend(cpu_device)
# return available_devices


# # Generate and return a new seed if the provided one is not in the
# # supported range (including -1)
# def sanitize_seed(seed: int | str):
# seed = int(seed)
# uint32_info = np.iinfo(np.uint32)
# uint32_min, uint32_max = uint32_info.min, uint32_info.max
# if seed < uint32_min or seed >= uint32_max:
# seed = randint(uint32_min, uint32_max)
# return seed


# # take a seed expression in an input format and convert it to
# # a list of integers, where possible
# def parse_seed_input(seed_input: str | list | int):
# if isinstance(seed_input, str):
# try:
# seed_input = json.loads(seed_input)
# except (ValueError, TypeError):
# seed_input = None

# if isinstance(seed_input, int):
# return [seed_input]

# if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
# return seed_input

# raise TypeError(
# "Seed input must be an integer or an array of integers in JSON format"
# )
Loading

0 comments on commit dac7a29

Please sign in to comment.