From 34259d47332565c53ea2c46afd7e39f126284e6f Mon Sep 17 00:00:00 2001
From: angelayi
Date: Tue, 17 Sep 2024 15:29:02 -0700
Subject: [PATCH] [aoti] Remove need for -l in cmake call

---
 README.md           |  2 +-
 runner/run.cpp      | 59 +++++++++++++++++++--------------------
 torchchat/export.py | 13 ++++++++--
 3 files changed, 37 insertions(+), 37 deletions(-)

diff --git a/README.md b/README.md
index e1f585c06..b7b393c8b 100644
--- a/README.md
+++ b/README.md
@@ -290,7 +290,7 @@ torchchat/utils/scripts/build_native.sh aoti
 
 Then run the compiled executable, with the pt2.
 ```bash
-cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
+cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -i "Once upon a time"
 ```
 
 ## Mobile Execution
diff --git a/runner/run.cpp b/runner/run.cpp
index fea6dc3ce..5c56fb14a 100644
--- a/runner/run.cpp
+++ b/runner/run.cpp
@@ -32,8 +32,6 @@ LICENSE file in the root directory of this source tree.
 #ifdef __AOTI_MODEL__
 #include
 
-torch::Device aoti_device(torch::kCPU);
-
 #else // __ET_MODEL__
 #include
 #include
@@ -88,9 +86,11 @@ typedef struct {
 typedef struct {
   Config config; // the hyperparameters of the architecture (the blueprint)
   RunState state; // buffers for the "wave" of activations in the forward pass
+  std::unordered_map<std::string, std::string> metadata;
 
 #ifdef __AOTI_MODEL__
   torch::inductor::AOTIModelPackageLoader* runner;
+
 #else // __ET_MODEL__
   Module* runner;
 #endif
@@ -129,19 +129,9 @@ void read_checkpoint(char* checkpoint, Config* config) {
 
 void build_transformer(
     Transformer* t,
-    char* model_path,
-    int vocab_size,
-    int seq_len) {
-  // read in the Config and the Weights from the model
-  // read_checkpoint(model_path, &t->config);
-  // allocate the RunState buffers
-  t->config.vocab_size = vocab_size;
-  t->config.seq_len = seq_len;
-  malloc_run_state(&t->state, &t->config);
-
+    char* model_path) {
 #ifdef __AOTI_MODEL__
   t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
-  aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA);
 #else //__ET_MODEL__
   t->runner = new Module(
       /* path to PTE model */ model_path,
@@ -193,6 +183,9 @@ float* forward(Transformer* transformer, int token, int pos) {
   torch::Tensor token_tensor =
       torch::from_blob(token_buffer, {1, 1}, torch::kLong);
   torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
+  torch::Device aoti_device = transformer->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu"
+      ? torch::Device(torch::kCPU)
+      : torch::Device(torch::kCUDA);
   std::vector<torch::Tensor> inputs{
       token_tensor.to(aoti_device), pos_tensor.to(aoti_device)};
 
@@ -880,26 +873,25 @@ int main(int argc, char* argv[]) {
       system_prompt = argv[i + 1];
     } else if (argv[i][1] == 'l') {
       llama_ver = atoi(argv[i + 1]);
-#ifdef __AOTI_MODEL__
-    } else if (argv[i][1] == 'd') {
-#ifdef USE_CUDA
-      if (strcasecmp(argv[i + 1], "CUDA") == 0) {
-        aoti_device = torch::Device(torch::kCUDA);
-      } else
-#endif
-          if (strcasecmp(argv[i + 1], "CPU") == 0) {
-        aoti_device = torch::Device(torch::kCPU);
-      } else {
-        fprintf(stderr, "Unknown device %s", argv[i + 1]);
-        exit(1);
-      }
-#endif
     } else {
       error_usage();
    }
  }
 
+  if (model_path == NULL) {
+    fprintf(stderr, "No model_path provided.");
+    error_usage();
+  }
+
+  Transformer transformer;
+  build_transformer(&transformer, model_path);
+
+#ifdef __AOTI_MODEL__
+  ModelType model_type = get_model_type(std::stoi(transformer.runner->get_metadata()["tokenizer_type"]));
+#else // __ET_MODEL__
   ModelType model_type = get_model_type(llama_ver);
+#endif
+
   if (model_type == UNKNOWN_MODEL) {
     fprintf(
         stderr,
@@ -908,11 +900,6 @@ int main(int argc, char* argv[]) {
     error_usage();
   }
 
-  if (model_path == NULL) {
-    fprintf(stderr, "No model_path provided.");
-    error_usage();
-  }
-
   if (tokenizer_path == NULL) {
     fprintf(stderr, "No tokenizer_path provided.");
     error_usage();
@@ -935,8 +922,12 @@ int main(int argc, char* argv[]) {
     vocab_size = tokenizer->vocab_size();
   }
 
-  Transformer transformer;
-  build_transformer(&transformer, model_path, vocab_size, steps);
+  // read in the Config and the Weights from the model
+  // read_checkpoint(model_path, &t->config);
+  // allocate the RunState buffers
+  transformer.config.vocab_size = vocab_size;
+  transformer.config.seq_len = steps;
+  malloc_run_state(&transformer.state, &transformer.config);
 
   Sampler sampler;
   build_sampler(&sampler, vocab_size, temperature, topp, rng_seed);
diff --git a/torchchat/export.py b/torchchat/export.py
index 033733b5a..7c76afc68 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-from typing import Optional
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
@@ -39,6 +39,7 @@ def export_for_server(
     output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
     package: bool = True,
+    metadata: Optional[Dict[str, str]] = None,
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.
@@ -67,7 +68,7 @@ def export_for_server(
         dynamic_shapes = None
 
     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        metadata = {}  # TODO: put more metadata here
+        metadata = metadata or {}
         options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
         if not package:
             options = {"aot_inductor.output_path": output_path}
@@ -373,6 +374,7 @@ def main(args):
 
     # TODO: clean this up
     # This mess is because ET does not support _weight_int4pack_mm right now
+    tokenizer_args = None
     if not builder_args.gguf_path:
         # tokenizer needed for quantization so get that here,
         try:
@@ -443,6 +445,12 @@ def main(args):
 
     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
+
+        tokenizer_type = "0"
+        if tokenizer_args is not None:
+            tokenizer_type = "2" if tokenizer_args.is_sentencepiece else "3"
+
+        metadata = {"tokenizer_type": tokenizer_type}
         print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
         export_for_server(
             model_to_aoti_package,
@@ -450,4 +458,5 @@ def main(args):
             output_aoti_package_path,
             builder_args.dynamic_shapes,
             package=True,
+            metadata=metadata,
         )
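
Note for reviewers (not part of the patch): the sketch below shows how the metadata that export_for_server() now embeds in the .pt2 can be read back on the C++ side, using the same AOTIModelPackageLoader::get_metadata() calls that runner/run.cpp relies on after this change. The standalone program, its file name, and the include path are assumptions for illustration only.

    // inspect_pt2.cpp -- minimal sketch, assuming the AOTI package loader header path.
    #include <iostream>
    #include <string>
    #include <torch/csrc/inductor/aoti_package/model_package_loader.h>

    int main(int argc, char* argv[]) {
      if (argc < 2) {
        std::cerr << "usage: inspect_pt2 <model.pt2>\n";
        return 1;
      }
      // Load the exported package and fetch its string-to-string metadata map.
      torch::inductor::AOTIModelPackageLoader loader(argv[1]);
      auto metadata = loader.get_metadata();
      // "AOTI_DEVICE_KEY" drives CPU vs. CUDA input placement in forward();
      // "tokenizer_type" replaces the old -l flag (2 = SentencePiece, 3 = Tiktoken).
      std::cout << "device: " << metadata["AOTI_DEVICE_KEY"] << "\n"
                << "tokenizer_type: " << metadata["tokenizer_type"] << "\n";
      return 0;
    }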