From 94c09e402adcdff60c235f603022fc74a6e123a5 Mon Sep 17 00:00:00 2001
From: saienduri
Date: Wed, 14 Feb 2024 03:52:23 -0800
Subject: [PATCH] turbine tank

---
 models/requirements.txt                          |   2 +
 .../custom_models/sd_inference/clip.py           |  32 ++
 .../custom_models/sd_inference/unet.py           |  31 ++
 .../custom_models/sd_inference/utils.py          |   2 +-
 .../custom_models/sd_inference/vae.py            |  31 ++
 .../custom_models/stateless_llama.py             |  42 +-
 .../turbine_models/turbine_tank/run_models.py    | 404 ++++++++++++++++++
 .../turbine_tank/turbine_tank.py                 | 143 +++++++
 8 files changed, 681 insertions(+), 6 deletions(-)
 create mode 100644 models/turbine_models/turbine_tank/run_models.py
 create mode 100644 models/turbine_models/turbine_tank/turbine_tank.py

diff --git a/models/requirements.txt b/models/requirements.txt
index 4d2d16a56..99678eb68 100644
--- a/models/requirements.txt
+++ b/models/requirements.txt
@@ -5,3 +5,5 @@ transformers
 accelerate
 diffusers==0.24.0
 brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
+# turbine tank downloading/uploading
+azure-storage-blob
diff --git a/models/turbine_models/custom_models/sd_inference/clip.py b/models/turbine_models/custom_models/sd_inference/clip.py
index 996d5fb83..a2ab030ef 100644
--- a/models/turbine_models/custom_models/sd_inference/clip.py
+++ b/models/turbine_models/custom_models/sd_inference/clip.py
@@ -16,6 +16,7 @@
 import torch
 import torch._dynamo as dynamo
 from transformers import CLIPTextModel, CLIPTokenizer
+from turbine_models.turbine_tank import turbine_tank
 
 import argparse
 
@@ -46,6 +47,18 @@
     help="Specify vulkan target triple or rocm/cuda target device.",
 )
 parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
+parser.add_argument(
+    "--download_ir",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="download IR from turbine tank",
+)
+parser.add_argument(
+    "--upload_ir",
+    action=argparse.BooleanOptionalAction,
+    default=False,
+    help="upload IR to turbine tank",
+)
 
 
 def export_clip_model(
@@ -57,6 +70,8 @@
     device=None,
     target_triple=None,
     max_alloc=None,
+    download_ir=False,
+    upload_ir=False,
 ):
     # Load the tokenizer and text encoder to tokenize and encode the text.
     tokenizer = CLIPTokenizer.from_pretrained(
@@ -64,6 +79,10 @@
         subfolder="tokenizer",
         token=hf_auth_token,
     )
+
+    if download_ir:
+        return turbine_tank.downloadModelArtifacts(hf_model_name + "-clip"), tokenizer
+
     text_encoder_model = CLIPTextModel.from_pretrained(
         hf_model_name,
         subfolder="text_encoder",
@@ -94,6 +113,15 @@ def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
 
     module_str = str(CompiledModule.get_mlir_module(inst))
     safe_name = utils.create_safe_name(hf_model_name, "-clip")
+    if upload_ir:
+        with open(f"{safe_name}.mlir", "w+") as f:
+            f.write(module_str)
+        model_name_upload = hf_model_name.replace("/", "_")
+        model_name_upload += "-clip"
+        turbine_tank.uploadToBlobStorage(
+            str(os.path.abspath(f"{safe_name}.mlir")),
+            f"{model_name_upload}/{model_name_upload}.mlir",
+        )
     if compile_to != "vmfb":
         return module_str, tokenizer
     else:
@@ -102,6 +130,8 @@ def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
 
 if __name__ == "__main__":
     args = parser.parse_args()
+    if args.upload_ir and args.download_ir:
+        raise ValueError("upload_ir and download_ir can't both be true")
     mod_str, _ = export_clip_model(
         args.hf_model_name,
         args.hf_auth_token,
@@ -111,6 +141,8 @@ def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
         args.device,
         args.iree_target_triple,
         args.vulkan_max_allocation,
+        args.download_ir,
+        args.upload_ir,
     )
     safe_name = args.hf_model_name.split("/")[-1].strip()
     safe_name = re.sub("-", "_", safe_name)
diff --git a/models/turbine_models/custom_models/sd_inference/unet.py b/models/turbine_models/custom_models/sd_inference/unet.py
index 272c7af7f..d193ded78 100644
--- a/models/turbine_models/custom_models/sd_inference/unet.py
+++ b/models/turbine_models/custom_models/sd_inference/unet.py
@@ -18,6 +18,7 @@
 import safetensors
 import argparse
 
+from turbine_models.turbine_tank import turbine_tank
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -53,6 +54,18 @@
     help="Specify vulkan target triple or rocm/cuda target device.",
 )
 parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
+parser.add_argument(
+    "--download_ir",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="download IR from turbine tank",
+)
+parser.add_argument(
+    "--upload_ir",
+    action=argparse.BooleanOptionalAction,
+    default=False,
+    help="upload IR to turbine tank",
+)
 
 
 class UnetModel(torch.nn.Module):
@@ -90,7 +103,12 @@ def export_unet_model(
     device=None,
     target_triple=None,
     max_alloc=None,
+    download_ir=False,
+    upload_ir=False,
 ):
+    if download_ir:
+        return turbine_tank.downloadModelArtifacts(hf_model_name + "-unet")
+
     mapper = {}
     utils.save_external_weights(
         mapper, unet_model, external_weights, external_weight_path
@@ -125,6 +143,15 @@ def main(
 
     module_str = str(CompiledModule.get_mlir_module(inst))
     safe_name = utils.create_safe_name(hf_model_name, "-unet")
+    if upload_ir:
+        with open(f"{safe_name}.mlir", "w+") as f:
+            f.write(module_str)
+        model_name_upload = hf_model_name.replace("/", "_")
+        model_name_upload += "-unet"
+        turbine_tank.uploadToBlobStorage(
+            str(os.path.abspath(f"{safe_name}.mlir")),
+            f"{model_name_upload}/{model_name_upload}.mlir",
+        )
     if compile_to != "vmfb":
         return module_str
     else:
@@ -133,6 +160,8 @@ def main(
 
 if __name__ == "__main__":
     args = parser.parse_args()
+    if args.upload_ir and args.download_ir:
+        raise ValueError("upload_ir and download_ir can't both be true")
     unet_model = UnetModel(
         args.hf_model_name,
         args.hf_auth_token,
@@ -150,6 +179,8 @@ def main(
         args.device,
         args.iree_target_triple,
         args.vulkan_max_allocation,
+        args.download_ir,
+        args.upload_ir,
     )
     safe_name = utils.create_safe_name(args.hf_model_name, "-unet")
     with open(f"{safe_name}.mlir", "w+") as f:
diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py
index 37787fd3a..c4898dac7 100644
--- a/models/turbine_models/custom_models/sd_inference/utils.py
+++ b/models/turbine_models/custom_models/sd_inference/utils.py
@@ -79,7 +79,7 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name):
         with open(f"{safe_name}.vmfb", "wb+") as f:
             f.write(flatbuffer_blob)
         print("Saved to", safe_name + ".vmfb")
-        exit()
+        return
 
 
 def create_safe_name(hf_model_name, model_name_str):
diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py
index 03ef85556..2aef05bcf 100644
--- a/models/turbine_models/custom_models/sd_inference/vae.py
+++ b/models/turbine_models/custom_models/sd_inference/vae.py
@@ -18,6 +18,7 @@
 import safetensors
 import argparse
 
+from turbine_models.turbine_tank import turbine_tank
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -54,6 +55,18 @@
 )
 parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
 parser.add_argument("--variant", type=str, default="decode")
+parser.add_argument(
+    "--download_ir",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="download IR from turbine tank",
+)
+parser.add_argument(
+    "--upload_ir",
+    action=argparse.BooleanOptionalAction,
+    default=False,
+    help="upload IR to turbine tank",
+)
 
 
 class VaeModel(torch.nn.Module):
@@ -89,7 +102,12 @@ def export_vae_model(
     target_triple=None,
     max_alloc=None,
     variant="decode",
+    download_ir=False,
+    upload_ir=False,
 ):
+    if download_ir:
+        return turbine_tank.downloadModelArtifacts(hf_model_name + "-" + variant)
+
     mapper = {}
     utils.save_external_weights(
         mapper, vae_model, external_weights, external_weight_path
@@ -113,6 +131,15 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
 
     module_str = str(CompiledModule.get_mlir_module(inst))
     safe_name = utils.create_safe_name(hf_model_name, "-vae")
+    if upload_ir:
+        with open(f"{safe_name}.mlir", "w+") as f:
+            f.write(module_str)
+        model_name_upload = hf_model_name.replace("/", "_")
+        model_name_upload = model_name_upload + "-" + variant
+        turbine_tank.uploadToBlobStorage(
+            str(os.path.abspath(f"{safe_name}.mlir")),
+            f"{model_name_upload}/{model_name_upload}.mlir",
+        )
     if compile_to != "vmfb":
         return module_str
     else:
@@ -121,6 +148,8 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
 
 if __name__ == "__main__":
     args = parser.parse_args()
+    if args.upload_ir and args.download_ir:
+        raise ValueError("upload_ir and download_ir can't both be true")
     vae_model = VaeModel(
         args.hf_model_name,
         args.hf_auth_token,
@@ -139,6 +168,8 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
         args.iree_target_triple,
         args.vulkan_max_allocation,
         args.variant,
+        args.download_ir,
+        args.upload_ir,
     )
     safe_name = utils.create_safe_name(args.hf_model_name, "-vae")
     with open(f"{safe_name}.mlir", "w+") as f:
diff --git a/models/turbine_models/custom_models/stateless_llama.py b/models/turbine_models/custom_models/stateless_llama.py
index 762690603..5e4c7ca1a 100644
--- a/models/turbine_models/custom_models/stateless_llama.py
+++ b/models/turbine_models/custom_models/stateless_llama.py
@@ -2,6 +2,7 @@
 import sys
 import re
 import json
+from turbine_models.turbine_tank import turbine_tank
 
 os.environ["TORCH_LOGS"] = "dynamic"
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -61,6 +62,18 @@
     action="store_true",
     help="Compile LLM with StreamingLLM optimizations",
 )
+parser.add_argument(
+    "--download_ir",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="download IR from turbine tank",
+)
+parser.add_argument(
+    "--upload_ir",
+    action=argparse.BooleanOptionalAction,
+    default=False,
+    help="upload IR to turbine tank",
+)
 
 
 def generate_schema(num_layers):
@@ -107,7 +120,18 @@ def export_transformer_model(
     vulkan_max_allocation=None,
     streaming_llm=False,
     vmfb_path=None,
+    download_ir=False,
+    upload_ir=False,
 ):
+    tokenizer = AutoTokenizer.from_pretrained(
+        hf_model_name,
+        use_fast=False,
+        token=hf_auth_token,
+    )
+
+    if download_ir:
+        return turbine_tank.downloadModelArtifacts(hf_model_name), tokenizer
+
     mod = AutoModelForCausalLM.from_pretrained(
         hf_model_name,
         torch_dtype=torch.float,
@@ -121,11 +145,7 @@
     if precision == "f16":
         mod = mod.half()
         dtype = torch.float16
-    tokenizer = AutoTokenizer.from_pretrained(
-        hf_model_name,
-        use_fast=False,
-        token=hf_auth_token,
-    )
+
     # TODO: generate these values instead of magic numbers
     NUM_LAYERS = mod.config.num_hidden_layers
     HEADS = getattr(mod.config, "num_key_value_heads", None)
@@ -319,6 +339,14 @@ def evict_kvcache_space(self):
 
     module_str = str(CompiledModule.get_mlir_module(inst))
     safe_name = hf_model_name.split("/")[-1].strip()
     safe_name = re.sub("-", "_", safe_name)
+    if upload_ir:
+        with open(f"{safe_name}.mlir", "w+") as f:
+            f.write(module_str)
+        model_name_upload = hf_model_name.replace("/", "_")
+        turbine_tank.uploadToBlobStorage(
+            str(os.path.abspath(f"{safe_name}.mlir")),
+            f"{model_name_upload}/{model_name_upload}.mlir",
+        )
     if compile_to != "vmfb":
         return module_str, tokenizer
     else:
@@ -382,6 +410,8 @@ def evict_kvcache_space(self):
 
 if __name__ == "__main__":
     args = parser.parse_args()
+    if args.upload_ir and args.download_ir:
+        raise ValueError("upload_ir and download_ir can't both be true")
     mod_str, _ = export_transformer_model(
         args.hf_model_name,
         args.hf_auth_token,
@@ -395,6 +425,8 @@ def evict_kvcache_space(self):
         args.vulkan_max_allocation,
         args.streaming_llm,
         args.vmfb_path,
+        args.download_ir,
+        args.upload_ir,
     )
     safe_name = args.hf_model_name.split("/")[-1].strip()
     safe_name = re.sub("-", "_", safe_name)
diff --git a/models/turbine_models/turbine_tank/run_models.py b/models/turbine_models/turbine_tank/run_models.py
new file mode 100644
index 000000000..5d612c4ee
--- /dev/null
+++ b/models/turbine_models/turbine_tank/run_models.py
@@ -0,0 +1,404 @@
+# Copyright 2023 Nod Labs, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import argparse
+from turbine_models.custom_models.sd_inference import (
+    clip,
+    clip_runner,
+    unet,
+    unet_runner,
+    vae,
+    vae_runner,
+)
+
+from turbine_models.custom_models.sd_inference import utils
+import torch
+import os
+import turbine_models.custom_models.stateless_llama as llama
+import difflib
+from turbine_models.turbine_tank import turbine_tank
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--download_ir",
+    action=argparse.BooleanOptionalAction,
+    default=False,
+    help="download IR from turbine tank",
+)
+parser.add_argument(
+    "--upload_ir",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="upload IR to turbine tank",
+)
+
+os.environ["TORCH_LOGS"] = "dynamic"
+from shark_turbine.aot import *
+from turbine_models.custom_models import llm_runner
+
+from turbine_models.gen_external_params.gen_external_params import (
+    gen_external_params,
+)
+
+DEFAULT_PROMPT = """[INST] <<SYS>>
+Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> hi what are you? [/INST]
+"""
+
+
+def check_output_string(reference, output):
+    # Calculate and return the unified diff between the two outputs
+    diff = difflib.unified_diff(
+        reference.splitlines(keepends=True),
+        output.splitlines(keepends=True),
+        fromfile="reference",
+        tofile="output",
+        lineterm="",
+    )
+    return "".join(diff)
+
+
+def run_llama_model(download_ir=False, upload_ir=True):
+    if not download_ir:
+        gen_external_params(
+            hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+            hf_auth_token=None,
+        )
+    llama.export_transformer_model(
+        hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+        hf_auth_token=None,
+        compile_to="vmfb",
+        external_weights="safetensors",
+        # external_weight_file="Llama-2-7b-chat-hf-function-calling-v2_f16_int4.safetensors",
+        # (do not export weights here because they don't get quantized)
+        quantization="int4",
+        precision="f16",
+        device="llvm-cpu",
+        target_triple="host",
+        download_ir=download_ir,
+        upload_ir=upload_ir,
+    )
+
+    if download_ir:
+        return
+
+    torch_str_cache_path = (
+        "models/turbine_models/tests/vmfb_comparison_cached_torch_output_f16_int4.txt"
+    )
+    # if cached, just read
+    if os.path.exists(torch_str_cache_path):
+        with open(torch_str_cache_path, "r") as f:
+            torch_str = f.read()
+    else:
+        torch_str = llm_runner.run_torch_llm(
+            "Trelis/Llama-2-7b-chat-hf-function-calling-v2", None, DEFAULT_PROMPT
+        )
+
+        with open(torch_str_cache_path, "w") as f:
+            f.write(torch_str)
+
+    turbine_str = llm_runner.run_llm(
+        "local-task",
+        DEFAULT_PROMPT,
+        "Llama_2_7b_chat_hf_function_calling_v2.vmfb",
+        "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+        None,
+        "Llama_2_7b_chat_hf_function_calling_v2_f16_int4.safetensors",
+    )
+
+    result = check_output_string(torch_str, turbine_str)
+
+    # clean up
+    os.remove("Llama_2_7b_chat_hf_function_calling_v2_f16_int4.safetensors")
+    os.remove("Llama_2_7b_chat_hf_function_calling_v2.vmfb")
+    os.remove("Llama_2_7b_chat_hf_function_calling_v2.mlir")
+
+    return result
+
+
+arguments = {
+    "hf_auth_token": None,
+    "hf_model_name": "CompVis/stable-diffusion-v1-4",
+    "batch_size": 1,
+    "height": 512,
+    "width": 512,
+    "run_vmfb": True,
+    "compile_to": None,
+    "external_weight_path": "",
+    "vmfb_path": "",
+    "external_weights": None,
+    "device": "local-task",
"", + "vulkan_max_allocation": "4294967296", + "prompt": "a photograph of an astronaut riding a horse", + "in_channels": 4, +} + + +unet_model = unet.UnetModel( + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, +) + +vae_model = vae.VaeModel( + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, +) + + +def run_clip_model(download_ir=False, upload_ir=True): + clip.export_clip_model( + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_clip.safetensors", + "cpu", + download_ir=download_ir, + upload_ir=upload_ir, + ) + + if download_ir: + return + + arguments["external_weight_path"] = "stable_diffusion_v1_4_clip.safetensors" + arguments["vmfb_path"] = "stable_diffusion_v1_4_clip.vmfb" + turbine = clip_runner.run_clip( + arguments["device"], + arguments["prompt"], + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + torch_output = clip_runner.run_torch_clip( + arguments["hf_model_name"], arguments["hf_auth_token"], arguments["prompt"] + ) + err = utils.largest_error(torch_output, turbine[0]) + if err < 9e-5: + result = "CLIP SUCCESS: " + str(err) + else: + result = "CLIP FAILURE: " + str(err) + + # clean up + os.remove("stable_diffusion_v1_4_clip.safetensors") + os.remove("stable_diffusion_v1_4_clip.vmfb") + os.remove("stable_diffusion_v1_4_clip.mlir") + + return result + + +def run_unet_model(download_ir=False, upload_ir=True): + unet.export_unet_model( + unet_model, + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + arguments["batch_size"], + arguments["height"], + arguments["width"], + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_unet.safetensors", + "cpu", + download_ir=download_ir, + upload_ir=upload_ir, + ) + + if download_ir: + return + + arguments["external_weight_path"] = "stable_diffusion_v1_4_unet.safetensors" + arguments["vmfb_path"] = "stable_diffusion_v1_4_unet.vmfb" + sample = torch.rand( + arguments["batch_size"], + arguments["in_channels"], + arguments["height"] // 8, + arguments["width"] // 8, + dtype=torch.float32, + ) + timestep = torch.zeros(1, dtype=torch.float32) + encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + + turbine = unet_runner.run_unet( + arguments["device"], + sample, + timestep, + encoder_hidden_states, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + torch_output = unet_runner.run_torch_unet( + arguments["hf_model_name"], + arguments["hf_auth_token"], + sample, + timestep, + encoder_hidden_states, + ) + err = utils.largest_error(torch_output, turbine) + if err < 9e-5: + result = "UNET SUCCESS: " + str(err) + else: + result = "UNET FAILURE: " + str(err) + + # clean up + os.remove("stable_diffusion_v1_4_unet.safetensors") + os.remove("stable_diffusion_v1_4_unet.vmfb") + os.remove("stable_diffusion_v1_4_unet.mlir") + + return result + + +def run_vae_decode(download_ir=False, upload_ir=True): + vae.export_vae_model( + vae_model, + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + arguments["batch_size"], + arguments["height"], + arguments["width"], + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_vae.safetensors", + "cpu", + variant="decode", + download_ir=download_ir, + upload_ir=upload_ir, + ) + + if download_ir: + return + + 
arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" + arguments["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb" + example_input = torch.rand( + arguments["batch_size"], + 4, + arguments["height"] // 8, + arguments["width"] // 8, + dtype=torch.float32, + ) + turbine = vae_runner.run_vae( + arguments["device"], + example_input, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + torch_output = vae_runner.run_torch_vae( + arguments["hf_model_name"], + arguments["hf_auth_token"], + "decode", + example_input, + ) + err = utils.largest_error(torch_output, turbine) + if err < 9e-5: + result = "VAE DECODE SUCCESS: " + str(err) + else: + result = "VAE DECODE FAILURE: " + str(err) + + # clean up + os.remove("stable_diffusion_v1_4_vae.safetensors") + os.remove("stable_diffusion_v1_4_vae.vmfb") + os.remove("stable_diffusion_v1_4_vae.mlir") + + return result + + +def run_vae_encode(download_ir=False, upload_ir=True): + vae.export_vae_model( + vae_model, + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + arguments["batch_size"], + arguments["height"], + arguments["width"], + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_vae.safetensors", + "cpu", + variant="encode", + download_ir=download_ir, + upload_ir=upload_ir, + ) + + if download_ir: + return + + arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" + arguments["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb" + example_input = torch.rand( + arguments["batch_size"], + 3, + arguments["height"], + arguments["width"], + dtype=torch.float32, + ) + turbine = vae_runner.run_vae( + arguments["device"], + example_input, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + torch_output = vae_runner.run_torch_vae( + arguments["hf_model_name"], + arguments["hf_auth_token"], + "encode", + example_input, + ) + err = utils.largest_error(torch_output, turbine) + if err < 2e-3: + result = "VAE ENCODE SUCCESS: " + str(err) + else: + result = "VAE ENCODE FAILURE: " + str(err) + + # clean up + os.remove("stable_diffusion_v1_4_vae.safetensors") + os.remove("stable_diffusion_v1_4_vae.vmfb") + os.remove("stable_diffusion_v1_4_vae.mlir") + + return result + + +if __name__ == "__main__": + args = parser.parse_args() + + if args.upload_ir and args.download_ir: + raise ValueError("upload_ir and download_ir can't both be true") + + if args.upload_ir: + result = "Turbine Tank Results\n" + llama_result = run_llama_model(args.download_ir, args.upload_ir) + result += llama_result + "\n" + clip_result = run_clip_model(args.download_ir, args.upload_ir) + result += clip_result + "\n" + unet_result = run_unet_model(args.download_ir, args.upload_ir) + result += unet_result + "\n" + vae_decode_result = run_vae_decode(args.download_ir, args.upload_ir) + result += vae_decode_result + "\n" + vae_encode_result = run_vae_encode(args.download_ir, args.upload_ir) + result += vae_encode_result + "\n" + f = open("daily_report.txt", "a") + f.write(result) + f.close() + turbine_tank.uploadToBlobStorage( + str(os.path.abspath("daily_report.txt")), "daily_report.txt" + ) + os.remove("daily_report.txt") + else: + run_llama_model(args.download_ir, args.upload_ir) + run_clip_model(args.download_ir, args.upload_ir) + run_unet_model(args.download_ir, args.upload_ir) + run_vae_decode(args.download_ir, args.upload_ir) + run_vae_encode(args.download_ir, 
diff --git a/models/turbine_models/turbine_tank/turbine_tank.py b/models/turbine_models/turbine_tank/turbine_tank.py
new file mode 100644
index 000000000..92a294e3e
--- /dev/null
+++ b/models/turbine_models/turbine_tank/turbine_tank.py
@@ -0,0 +1,143 @@
+from azure.storage.blob import BlobServiceClient
+
+import subprocess
+import datetime
+import os
+from pathlib import Path
+
+custom_path = os.getenv("TURBINE_TANK_CACHE_DIR")
+if custom_path is not None:
+    if not os.path.exists(custom_path):
+        os.mkdir(custom_path)
+
+    WORKDIR = custom_path
+
+    print(f"Using {WORKDIR} as local turbine_tank cache directory.")
+else:
+    WORKDIR = os.path.join(str(Path.home()), ".local/turbine_tank/")
+    print(
+        f"turbine_tank local cache is located at {WORKDIR}. You may change this by setting the TURBINE_TANK_CACHE_DIR environment variable."
+    )
+os.makedirs(WORKDIR, exist_ok=True)
+
+storage_account_key = "XSsr+KqxBLxXzRtFv3QbbdsAxdwDGe661Q1xY4ziMRtpCazN8W6HZePi6nwud5RNLC5Y7e410abg+AStyzmX1A=="
+storage_account_name = "tankturbine"
+connection_string = "DefaultEndpointsProtocol=https;AccountName=tankturbine;AccountKey=XSsr+KqxBLxXzRtFv3QbbdsAxdwDGe661Q1xY4ziMRtpCazN8W6HZePi6nwud5RNLC5Y7e410abg+AStyzmX1A==;EndpointSuffix=core.windows.net"
+container_name = "tankturbine"
+
+
+def get_short_git_sha() -> str:
+    try:
+        return (
+            subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
+            .decode("utf-8")
+            .strip()
+        )
+    except FileNotFoundError:
+        return None
+
+
+def uploadToBlobStorage(file_path, file_name):
+    # create our prefix (we use this to keep track of when and what version of turbine is being used)
+    today = str(datetime.date.today())
+    commit = get_short_git_sha()
+    prefix = today + "_" + commit
+    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
+    blob_client = blob_service_client.get_blob_client(
+        container=container_name, blob=prefix + "/" + file_name
+    )
+    # check whether the blob was already uploaded (we don't want duplicates)
+    if blob_client.exists():
+        print(
+            f"model artifacts have already been uploaded for {today} on the same github commit ({commit})"
+        )
+        return
+    # upload to the azure storage container tankturbine
+    with open(file_path, "rb") as data:
+        blob_client.upload_blob(data)
+        print(f"Uploaded {file_name}.")
+
+
+def checkAndRemoveIfDownloadedOld(model_name: str, model_dir: str, prefix: str):
+    if os.path.isdir(model_dir) and len(os.listdir(model_dir)) == 1:
+        for item in os.listdir(model_dir):
+            item_path = os.path.join(model_dir, item)
+            # model artifacts are already downloaded and up to date
+            # (we check whether they are behind using the prefix: day + git sha)
+            if os.path.isdir(item_path) and item == prefix:
+                return True
+            # model artifacts are behind, so remove them before the new download
+            if os.path.isdir(item_path) and os.path.isfile(
+                os.path.join(item_path, model_name + ".mlir")
+            ):
+                os.remove(os.path.join(item_path, model_name + ".mlir"))
+                os.rmdir(item_path)
+                return False
+    # this model's artifacts have not been downloaded yet
+    return False
+
+
+def download_public_folder(model_name: str, prefix: str, model_dir: str):
+    """Downloads a folder of blobs from the Azure container."""
+    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
+    container_client = blob_service_client.get_container_client(
+        container=container_name
+    )
+    blob_list = container_client.list_blobs(name_starts_with=prefix)
+
+    # go through the blobs with our target prefix
+    # example prefix: "2024-02-13_26d6428/CompVis_stable-diffusion-v1-4-clip"
+    for blob in blob_list:
+        blob_client = blob_service_client.get_blob_client(
+            container=container_name, blob=blob.name
+        )
+        # make sure the local destination directory exists
+        os.makedirs(model_dir, exist_ok=True)
+        # download the blob into the local turbine tank cache
+        with open(
+            file=os.path.join(model_dir, model_name + ".mlir"), mode="wb"
+        ) as sample_blob:
+            download_stream = blob_client.download_blob()
+            sample_blob.write(download_stream.readall())
+
+
+def downloadModelArtifacts(model_name: str) -> str:
+    model_name = model_name.replace("/", "_")
+    container_client = BlobServiceClient.from_connection_string(
+        connection_string
+    ).get_container_client(container=container_name)
+    blob_list = container_client.list_blobs()
+    # get the latest blob uploaded to turbine tank
+    # (blob_list is a paged iterator, so it does not support indexing)
+    for blob in blob_list:
+        latest_blob = blob
+    # get the prefix of the latest blob (e.g. 2024-02-13_26d6428)
+    download_latest_prefix = latest_blob.name.split("/")[0]
+    model_dir = os.path.join(WORKDIR, model_name)
+    # check if we already downloaded the model artifacts for this day + commit
+    exists = checkAndRemoveIfDownloadedOld(
+        model_name=model_name, model_dir=model_dir, prefix=download_latest_prefix
+    )
+    if exists:
+        print("Already downloaded most recent version")
+        return "NA"
+    # download the model artifacts (model name, path to the artifacts in azure storage, local directory to store them)
+    download_public_folder(
+        model_name,
+        download_latest_prefix + "/" + model_name,
+        os.path.join(model_dir, download_latest_prefix),
+    )
+    model_dir = os.path.join(WORKDIR, model_name, download_latest_prefix)
+    mlir_filename = os.path.join(model_dir, model_name + ".mlir")
+    print(
+        f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..."
+    )
+    assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}"
+
+    return mlir_filename
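
Usage sketch (illustrative): with the flags added in this patch, callers can pull
the most recently uploaded IR from turbine tank instead of re-exporting. This
assumes export_clip_model's earlier parameters keep their pre-existing defaults;
the model name below is just an example.

    from turbine_models.custom_models.sd_inference import clip

    # download_ir short-circuits the export and returns the cached IR path
    # along with the tokenizer (download_ir and upload_ir are mutually
    # exclusive, as enforced in each script's __main__ block).
    mlir_path, tokenizer = clip.export_clip_model(
        "CompVis/stable-diffusion-v1-4",  # public model, no auth token needed
        None,  # hf_auth_token
        download_ir=True,
    )
    print("IR cached at:", mlir_path)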