diff --git a/python/turbine_models/model_builder.py b/python/turbine_models/model_builder.py index 22139ca64..03b56dc1a 100644 --- a/python/turbine_models/model_builder.py +++ b/python/turbine_models/model_builder.py @@ -1,6 +1,9 @@ -from transformers import AutoModel, AutoTokenizer, AutoConfig +from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM +import safetensors +from iree.compiler.ir import Context import torch import shark_turbine.aot as aot +from shark_turbine.aot import * class HFTransformerBuilder: @@ -19,10 +22,10 @@ def __init__( self, example_input: torch.Tensor, hf_id: str, - auto_model: AutoModel = AutoModel, - auto_tokenizer: AutoTokenizer = None, + auto_model: AutoModel = AutoModelForCausalLM, + auto_tokenizer: AutoTokenizer = AutoTokenizer, auto_config: AutoConfig = None, - hf_auth_token=None, + hf_auth_token="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", ) -> None: self.example_input = example_input self.hf_id = hf_id @@ -40,14 +43,14 @@ def build_model(self) -> None: """ # TODO: check cloud storage for existing ir self.model = self.auto_model.from_pretrained( - self.hf_id, token=self.hf_auth_token, config=self.auto_config + self.hf_id, token=self.hf_auth_token, torch_dtype=torch.float, trust_remote_code=True ) - if self.auto_tokenizer is not None: - self.tokenizer = self.auto_tokenizer.from_pretrained( - self.hf_id, token=self.hf_auth_token - ) - else: - self.tokenizer = None + #if self.auto_tokenizer is not None: + # self.tokenizer = self.auto_tokenizer.from_pretrained( + # self.hf_id, token=self.hf_auth_token, use_fast=False + # ) + #else: + self.tokenizer = None def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule: """ @@ -62,3 +65,33 @@ def get_compiled_module(self, save_to: str = None) -> aot.CompiledModule: module = aot.export(self.model, self.example_input) compiled_binary = module.compile(save_to=save_to) return compiled_binary + + +if __name__ == "__main__": + import sys + hf_id = sys.argv[-1] + safe_name = hf_id.replace("/", "_").replace("-", "_") + inp = torch.zeros(1, 1, dtype=torch.int64) + model = HFTransformerBuilder(inp, hf_id) + mapper=dict() + mod_params = dict(model.model.named_parameters()) + for name in mod_params: + mapper["params." + name] = name +# safetensors.torch.save_file(mod_params, safe_name+".safetensors") + class GlobalModule(CompiledModule): + params = export_parameters(model.model, external=True, external_scope="",) + compute = jittable(model.model.forward) + + def run(self, x=abstractify(inp)): + return self.compute(x) + + print("module defined") + inst = GlobalModule(context=Context()) + print("module inst") + module = CompiledModule.get_mlir_module(inst) +# compiled = module.compile() + print("got mlir module") + with open(safe_name+".mlir", "w+") as f: + f.write(str(module)) + + print("done")