Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
saienduri committed Feb 21, 2024
1 parent a465a48 commit ac89976
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 33 deletions.
7 changes: 0 additions & 7 deletions models/turbine_models/custom_models/sd_inference/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@
help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
parser.add_argument(
"--upload_ir",
action=argparse.BooleanOptionalAction,
default=False,
help="upload IR to turbine tank",
)


def export_clip_model(
Expand Down Expand Up @@ -129,7 +123,6 @@ def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
args.device,
args.iree_target_triple,
args.vulkan_max_allocation,
args.upload_ir,
)
safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
Expand Down
7 changes: 0 additions & 7 deletions models/turbine_models/custom_models/sd_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,6 @@
help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
parser.add_argument(
"--upload_ir",
action=argparse.BooleanOptionalAction,
default=False,
help="upload IR to turbine tank",
)


class UnetModel(torch.nn.Module):
Expand Down Expand Up @@ -167,7 +161,6 @@ def main(
args.device,
args.iree_target_triple,
args.vulkan_max_allocation,
args.upload_ir,
)
safe_name = utils.create_safe_name(args.hf_model_name, "-unet")
with open(f"{safe_name}.mlir", "w+") as f:
Expand Down
7 changes: 0 additions & 7 deletions models/turbine_models/custom_models/sd_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,6 @@
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
parser.add_argument("--variant", type=str, default="decode")
parser.add_argument(
"--upload_ir",
action=argparse.BooleanOptionalAction,
default=False,
help="upload IR to turbine tank",
)


class VaeModel(torch.nn.Module):
Expand Down Expand Up @@ -156,7 +150,6 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
args.iree_target_triple,
args.vulkan_max_allocation,
args.variant,
args.upload_ir,
)
safe_name = utils.create_safe_name(args.hf_model_name, "-vae")
with open(f"{safe_name}.mlir", "w+") as f:
Expand Down
7 changes: 0 additions & 7 deletions models/turbine_models/custom_models/stateless_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,6 @@
action="store_true",
help="Compile LLM with StreamingLLM optimizations",
)
parser.add_argument(
"--upload_ir",
action=argparse.BooleanOptionalAction,
default=False,
help="upload IR to turbine tank",
)


def generate_schema(num_layers):
Expand Down Expand Up @@ -413,7 +407,6 @@ def evict_kvcache_space(self):
args.vulkan_max_allocation,
args.streaming_llm,
args.vmfb_path,
args.upload_ir,
)
safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
Expand Down
9 changes: 4 additions & 5 deletions models/turbine_models/turbine_tank/turbine_tank.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@
print(
f"turbine_tank local cache is located at {WORKDIR} . You may change this by assigning the TURBINE_TANK_CACHE_DIR environment variable."
)
os.makedirs(WORKDIR, exist_ok=True)

storage_account_key = "XSsr+KqxBLxXzRtFv3QbbdsAxdwDGe661Q1xY4ziMRtpCazN8W6HZePi6nwud5RNLC5Y7e410abg+AStyzmX1A=="
storage_account_name = "tankturbine"
connection_string = "DefaultEndpointsProtocol=https;AccountName=tankturbine;AccountKey=XSsr+KqxBLxXzRtFv3QbbdsAxdwDGe661Q1xY4ziMRtpCazN8W6HZePi6nwud5RNLC5Y7e410abg+AStyzmX1A==;EndpointSuffix=core.windows.net"
container_name = "tankturbine"
storage_account_key = os.environ.get("AZURE_STORAGE_ACCOUNT_KEY")
storage_account_name = os.environ.get("AZURE_STORAGE_ACCOUNT_NAME")
connection_string = os.environ.get("AZURE_CONNECTION_STRING")
container_name = os.environ.get("AZURE_CONTAINER_NAME")


def get_short_git_sha() -> str:
Expand Down

0 comments on commit ac89976

Please sign in to comment.