Skip to content

Commit

Permalink
convert compile_candidate.sh to py function
Browse files Browse the repository at this point in the history
  • Loading branch information
RattataKing committed Aug 13, 2024
1 parent 5abb073 commit 5bde3d1
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 48 deletions.
157 changes: 109 additions & 48 deletions tuning/autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import pickle
import iree.runtime as ireert
import random
import autotune_functions


"""
Sample Usage:
Expand Down Expand Up @@ -94,7 +96,8 @@ class PathConfig:
candidates_dir: Path = field(init=False)
candidate_configs_pkl: Path = field(init=False)
compiled_dir: Path = field(init=False)
compilefailed_dir: Path = field(init=False)
compile_failed_dir: Path = field(init=False)
spec_dir: Path = field(init=False)

output_unilog: Path = field(init=False)
result_summary_log: Path = field(init=False)
Expand All @@ -117,7 +120,8 @@ def __post_init__(self):
self, "candidate_configs_pkl", self.candidates_dir / "configs.pkl"
)
object.__setattr__(self, "compiled_dir", self.candidates_dir / "compiled")
object.__setattr__(self, "compilefailed_dir", self.candidates_dir / "failed")
object.__setattr__(self, "compile_failed_dir", self.candidates_dir / "failed")
object.__setattr__(self, "spec_dir", self.candidates_dir / "spec")
object.__setattr__(self, "output_unilog", self.base_dir / "output.log")
object.__setattr__(
self, "result_summary_log", self.base_dir / "result_summary.log"
Expand All @@ -137,8 +141,17 @@ def _set_run_log(self, run_log: Path):
def get_candidate_mlir_path(self, candidate_id: int) -> Path:
return self.candidates_dir / f"{candidate_id}.mlir"

def get_candidate_vmfb_path(self, candidate_id: int) -> Path:
return self.compiled_dir / f"{candidate_id}.vmfb"

def get_candidate_failed_vmfb_path(self, candidate_id: int) -> Path:
return self.compile_failed_dir / f"{candidate_id}.vmfb"

def get_candidate_config_mlir_path(self, candidate_id: int) -> Path:
return self.candidates_dir / f"{candidate_id}_config.mlir"

def get_candidate_spec_mlir_path(self, candidate_id: int) -> Path:
return self.candidates_dir / "configs" / f"{candidate_id}_spec.mlir"
return self.spec_dir / f"{candidate_id}_spec.mlir"

def get_exe_format(self, path: Path) -> str:
return f"./{path.as_posix()}"
Expand Down Expand Up @@ -778,77 +791,125 @@ def collision_handler(index_hash_list: list[tuple[int, str]]) -> tuple[bool, lis
return collision_detected, unique_indexes


def write_candidate_spec_file(path_config: PathConfig, candidate_id: int) -> None:
# Open the files correctly, ensuring that the file object and Path object are distinct
with path_config.local_config_prolog_mlir.open(
"r"
) as prolog_file, path_config.get_candidate_config_mlir_path(candidate_id).open(
"r"
) as config_file, path_config.local_config_epilog_mlir.open(
"r"
) as epilog_file, (
path_config.spec_dir / f"{candidate_id}_spec.mlir"
).open(
"w"
) as spec_file:

# Read contents from prolog_file, config_file, and epilog_file
prolog_content = prolog_file.read()
config_content = config_file.read()
epilog_content = epilog_file.read()

# Write the contents to the spec file
spec_file.write(prolog_content)
spec_file.write(config_content)
spec_file.write(epilog_content)


def make_CompileDispatchTaskPack_task_list(
args: argparse.Namespace, path_config: PathConfig, candidates: list[int]
):
task_list = []
for candidate_id in candidates:
input_path = path_config.get_candidate_mlir_path(candidate_id)
out_path = path_config.get_candidate_vmfb_path(candidate_id)
mode = args.mode
task_list.append(
autotune_functions.CompileDispatchTaskPack(
candidate_id, input_path.as_posix(), out_path.as_posix(), mode
)
)
return task_list


def compile_candidates(
args: argparse.Namespace,
path_config: PathConfig,
candidates: list[int],
candidate_trackers: list[CandidateTracker],
) -> list[int]:
"""Compile candidate files for tuning and record in candidate_vmfbs.txt. Returns the list of compiled candidate indexes."""
logging.info("compile_candidates()")

task_list = []
for candidate_index in candidates:
mlir_path = candidate_trackers[candidate_index].mlir_path
assert mlir_path is not None
command = [
path_config.get_exe_format(path_config.compile_candidate_sh),
args.mode,
mlir_path.as_posix(),
]
task_list.append(TaskTuple(args, command, check=False))
path_config.compiled_dir.mkdir(parents=True, exist_ok=True)
path_config.compile_failed_dir.mkdir(parents=True, exist_ok=True)
path_config.spec_dir.mkdir(parents=True, exist_ok=True)

num_worker = max(min(args.max_cpu_workers, len(task_list)), 1) # at least 1 worker
multiprocess_progress_wrapper(
num_worker=num_worker, task_list=task_list, function=run_command_wrapper
task_list = make_CompileDispatchTaskPack_task_list(args, path_config, candidates)
num_worker = max(min(args.max_cpu_workers, len(candidates)), 1)
results: list[autotune_functions.CompileDispatchResultPack] = (
multiprocess_progress_wrapper(
num_worker=num_worker,
task_list=task_list,
function=autotune_functions.compile_dispatch,
)
)

compiled_files = sorted(
path_config.compiled_dir.glob("*.vmfb"), key=numerical_sort_key
)
failed_files = sorted(
path_config.compilefailed_dir.glob("*.mlir"), key=numerical_sort_key
)
compile_candidates = []
failed_candidates = []
compiled_candidates_hash_list = []
# process results from autotune_functions.compile_dispatch()
for result in results:
candidate_id = result.candidate_id
candidate_trackers[candidate_id].compilation_successful = result.success

input_path = candidate_trackers[candidate_id].mlir_path
assert input_path is not None
out_path = path_config.get_candidate_vmfb_path(candidate_id)

if result.success == False:
failed_candidates.append(candidate_id)
# move candidates that failed to compile to a different dir. (i.e. candidates/compiled/#.mlir -> candidates/failed/#.mlir)
input_failed_path = path_config.get_candidate_failed_vmfb_path(candidate_id)
input_path.rename(input_failed_path)
candidate_trackers[candidate_id].mlir_path = input_failed_path
if out_path.exists():
out_path.unlink()
continue

# collect hash val for compiled files
compile_candidates.append(candidate_id)
candidate_trackers[candidate_id].compiled_vmfb_path = out_path
hash_val = calculate_md5(out_path)
candidate_trackers[candidate_id].compiled_vmfb_hash = hash_val
compiled_candidates_hash_list.append((candidate_id, hash_val))

# Combine config files if present (0_config.mlir doesn't exist)
if candidate_id == 0:
continue
write_candidate_spec_file(path_config, candidate_id)

total, good, bad = len(task_list), len(compiled_files), len(failed_files)
compiling_rate = good / total * 100
compilation_rate = len(compile_candidates) / len(candidates) * 100
logging.critical(
f"Total: {total} | Compiled: {good} | Failed: {bad} | Compiling Rate: {compiling_rate:.1f}%"
f"Total: {len(candidates)} | Compiled: {len(compile_candidates)} | Failed: {len(failed_candidates)} | Compilation Rate: {compilation_rate:.1f}%"
)

# Update candidate tracker
for failed_file in failed_files:
index = int(failed_file.stem)
candidate_trackers[index].compilation_successful = False
compiled_candidates = []
compiled_candidates_hash_list = []
for compiled_file in compiled_files:
index = int(compiled_file.stem)
compiled_candidates.append(index)
candidate_trackers[index].compilation_successful = True
candidate_trackers[index].compiled_vmfb_path = compiled_file
compiled_vmfb_path = candidate_trackers[index].compiled_vmfb_path
assert compiled_vmfb_path is not None
hash_val = calculate_md5(compiled_vmfb_path)
candidate_trackers[index].compiled_vmfb_hash = hash_val
compiled_candidates_hash_list.append((index, hash_val))

handle_error(
condition=(good == 0), msg="Failed to compile all candidate .mlir files"
condition=(len(compile_candidates) == 0),
msg="Failed to compile all candidate .mlir files",
)
handle_error(
condition=(compiling_rate < 10),
msg=f"Compiling rate [{compiling_rate:.1f}%] < 10%",
condition=(compilation_rate < 10),
msg=f"Compiling rate [{compilation_rate:.1f}%] < 10%",
level=logging.WARNING,
)

# check collision and make a list of unique candidates
collision_detected, unique_indexes = collision_handler(
compiled_candidates_hash_list
)
if collision_detected:
logging.critical(f"Remains [{len(unique_indexes)}] unique candidate indexes")

return compiled_candidates if not collision_detected else unique_indexes
return compile_candidates if not collision_detected else unique_indexes


def parse_dispatch_benchmark_results(
Expand Down
82 changes: 82 additions & 0 deletions tuning/autotune_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import subprocess
from dataclasses import dataclass


@dataclass
class CompileDispatchTaskPack:
candidate_id: int
input_path_str: str
out_path_str: str
mode: str


@dataclass
class CompileDispatchResultPack:
candidate_id: int
success: bool


def compile_dispatch(
compile_dispatch_task_pack: CompileDispatchTaskPack,
) -> CompileDispatchResultPack:
candidate_id = compile_dispatch_task_pack.candidate_id
input_path_str = compile_dispatch_task_pack.input_path_str
out_path_str = compile_dispatch_task_pack.out_path_str
mode = compile_dispatch_task_pack.mode

try:
subprocess.run(
[
"timeout",
"4s",
"./punet.sh",
input_path_str,
"-o",
out_path_str,
"--compile-from=executable-sources",
],
check=True,
stderr=subprocess.DEVNULL,
)
except subprocess.CalledProcessError:
return CompileDispatchResultPack(success=False, candidate_id=candidate_id)

# Check for 'rocm-hsaco-fb' in the output
try:
result = subprocess.run(
["tools/iree-dump-module", out_path_str],
capture_output=True,
text=True,
check=True,
)
if "rocm-hsaco-fb" not in result.stdout:
raise RuntimeError
except (subprocess.CalledProcessError, RuntimeError):
return CompileDispatchResultPack(success=False, candidate_id=candidate_id)

# TODO: Add local logger to print this message
# print(f"Compiling {candidate_id}: success")
return CompileDispatchResultPack(success=True, candidate_id=candidate_id)


def benchmark_dispatch():
# TODO
pass


def compile_model():
# TODO
pass


def benchmark_model():
# TODO
pass


def main():
return


if __name__ == "__main__":
main()

0 comments on commit 5bde3d1

Please sign in to comment.