From 5443da5d2c8c5e87c8a3881da0de3ed24bab9fad Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 19 Jan 2024 18:29:56 +0000
Subject: [PATCH] Revert "vllm check"

This reverts commit 29f4259a35a39ba3509feb6b0cb63ebe64e16214.
---
 python/turbine_models/model_builder.py | 55 ++++++--------------------
 1 file changed, 11 insertions(+), 44 deletions(-)

diff --git a/python/turbine_models/model_builder.py b/python/turbine_models/model_builder.py
index 03b56dc1a..22139ca64 100644
--- a/python/turbine_models/model_builder.py
+++ b/python/turbine_models/model_builder.py
@@ -1,9 +1,6 @@
-from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM
-import safetensors
-from iree.compiler.ir import Context
+from transformers import AutoModel, AutoTokenizer, AutoConfig
 import torch
 import shark_turbine.aot as aot
-from shark_turbine.aot import *
 
 
 class HFTransformerBuilder:
@@ -22,10 +19,10 @@ def __init__(
         self,
         example_input: torch.Tensor,
         hf_id: str,
-        auto_model: AutoModel = AutoModelForCausalLM,
-        auto_tokenizer: AutoTokenizer = AutoTokenizer,
+        auto_model: AutoModel = AutoModel,
+        auto_tokenizer: AutoTokenizer = None,
         auto_config: AutoConfig = None,
-        hf_auth_token="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",
+        hf_auth_token=None,
     ) -> None:
         self.example_input = example_input
         self.hf_id = hf_id
@@ -43,14 +40,14 @@ def build_model(self) -> None:
         """
         # TODO: check cloud storage for existing ir
         self.model = self.auto_model.from_pretrained(
-            self.hf_id, token=self.hf_auth_token, torch_dtype=torch.float, trust_remote_code=True
+            self.hf_id, token=self.hf_auth_token, config=self.auto_config
         )
-        #if self.auto_tokenizer is not None:
-        #    self.tokenizer = self.auto_tokenizer.from_pretrained(
-        #        self.hf_id, token=self.hf_auth_token, use_fast=False
-        #    )
-        #else:
-        self.tokenizer = None
+        if self.auto_tokenizer is not None:
+            self.tokenizer = self.auto_tokenizer.from_pretrained(
+                self.hf_id, token=self.hf_auth_token
+            )
+        else:
+            self.tokenizer = None
 
     def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule:
         """
@@ -65,33 +62,3 @@ def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule:
         module = aot.export(self.model, self.example_input)
         compiled_binary = module.compile(save_to=save_to)
         return compiled_binary
-
-
-if __name__ == "__main__":
-    import sys
-    hf_id = sys.argv[-1]
-    safe_name = hf_id.replace("/", "_").replace("-", "_")
-    inp = torch.zeros(1, 1, dtype=torch.int64)
-    model = HFTransformerBuilder(inp, hf_id)
-    mapper=dict()
-    mod_params = dict(model.model.named_parameters())
-    for name in mod_params:
-        mapper["params." + name] = name
-#    safetensors.torch.save_file(mod_params, safe_name+".safetensors")
-    class GlobalModule(CompiledModule):
-        params = export_parameters(model.model, external=True, external_scope="",)
-        compute = jittable(model.model.forward)
-
-        def run(self, x=abstractify(inp)):
-            return self.compute(x)
-
-    print("module defined")
-    inst = GlobalModule(context=Context())
-    print("module inst")
-    module = CompiledModule.get_mlir_module(inst)
-#    compiled = module.compile()
-    print("got mlir module")
-    with open(safe_name+".mlir", "w+") as f:
-        f.write(str(module))
-
-    print("done")
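
For reviewers, a minimal sketch of how the restored HFTransformerBuilder API is
driven end to end after this revert. The model id "bert-base-uncased" and the
in-memory compile are illustrative assumptions, not taken from this diff.

    import torch

    from turbine_models.model_builder import HFTransformerBuilder

    # A (1, 1) int64 example input mirrors the shape used by the removed
    # __main__ block; any tensor appropriate for the model works here.
    example_input = torch.zeros(1, 1, dtype=torch.int64)

    # "bert-base-uncased" is a hypothetical model id chosen for illustration.
    builder = HFTransformerBuilder(
        example_input=example_input, hf_id="bert-base-uncased"
    )
    # Load the HF model; harmless to call again if __init__ already did so.
    builder.build_model()

    # Export through shark_turbine's AOT path and compile in memory;
    # pass a filesystem path via save_to= to write the artifact to disk.
    compiled = builder.get_compiled_module(save_to=None)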