Skip to content

Commit

Permalink
Revert "vllm check"
Browse the repository at this point in the history
This reverts commit 29f4259.
  • Loading branch information
Ubuntu committed Jan 19, 2024
1 parent 29f4259 commit 5443da5
Showing 1 changed file with 11 additions and 44 deletions.
55 changes: 11 additions & 44 deletions python/turbine_models/model_builder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM
import safetensors
from iree.compiler.ir import Context
from transformers import AutoModel, AutoTokenizer, AutoConfig
import torch
import shark_turbine.aot as aot
from shark_turbine.aot import *


class HFTransformerBuilder:
Expand All @@ -22,10 +19,10 @@ def __init__(
self,
example_input: torch.Tensor,
hf_id: str,
auto_model: AutoModel = AutoModelForCausalLM,
auto_tokenizer: AutoTokenizer = AutoTokenizer,
auto_model: AutoModel = AutoModel,
auto_tokenizer: AutoTokenizer = None,
auto_config: AutoConfig = None,
hf_auth_token="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",
hf_auth_token=None,
) -> None:
self.example_input = example_input
self.hf_id = hf_id
Expand All @@ -43,14 +40,14 @@ def build_model(self) -> None:
"""
# TODO: check cloud storage for existing ir
self.model = self.auto_model.from_pretrained(
self.hf_id, token=self.hf_auth_token, torch_dtype=torch.float, trust_remote_code=True
self.hf_id, token=self.hf_auth_token, config=self.auto_config
)
#if self.auto_tokenizer is not None:
# self.tokenizer = self.auto_tokenizer.from_pretrained(
# self.hf_id, token=self.hf_auth_token, use_fast=False
# )
#else:
self.tokenizer = None
if self.auto_tokenizer is not None:
self.tokenizer = self.auto_tokenizer.from_pretrained(
self.hf_id, token=self.hf_auth_token
)
else:
self.tokenizer = None

def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule:
"""
Expand All @@ -65,33 +62,3 @@ def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule:
module = aot.export(self.model, self.example_input)
compiled_binary = module.compile(save_to=save_to)
return compiled_binary


if __name__ == "__main__":
    # Ad-hoc export driver: build the Hugging Face model named on the command
    # line and write its MLIR module as text to "<safe_name>.mlir" in the
    # current working directory.
    import sys

    # The model id is read from the last command-line argument. With no
    # arguments, sys.argv[-1] is this script's own path, so stop early with a
    # usage message instead of treating that path as a model id.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <hf_model_id>")
    hf_id = sys.argv[-1]

    # Derive the output file stem from the id by replacing "/" and "-" with "_".
    safe_name = hf_id.replace("/", "_").replace("-", "_")

    # Example input used both to construct the builder and to describe the
    # argument of the exported `run` entry point.
    inp = torch.zeros(1, 1, dtype=torch.int64)
    model = HFTransformerBuilder(inp, hf_id)

    # NOTE: the parameters below are exported with external=True and this
    # script does not write the weights themselves. The original left the
    # following disabled:
    # safetensors.torch.save_file(
    #     dict(model.model.named_parameters()), safe_name + ".safetensors"
    # )

    class GlobalModule(CompiledModule):
        params = export_parameters(
            model.model,
            external=True,
            external_scope="",
        )
        compute = jittable(model.model.forward)

        def run(self, x=abstractify(inp)):
            return self.compute(x)

    print("module defined")
    inst = GlobalModule(context=Context())
    print("module inst")
    module = CompiledModule.get_mlir_module(inst)
    # compiled = module.compile()
    print("got mlir module")
    # The file is only written, never read back, so plain "w" is enough; the
    # encoding is pinned so the output does not depend on the platform default.
    with open(safe_name + ".mlir", "w", encoding="utf-8") as f:
        f.write(str(module))

    print("done")

0 comments on commit 5443da5

Please sign in to comment.