From b9d6ef5f988cba9f3dce3cda1a618e8f0ac9c53e Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 17 Sep 2024 10:55:39 -0700 Subject: [PATCH] Update AOTI package --- .ci/scripts/validate.sh | 36 +++++++------- .github/workflows/pull.yml | 8 +-- .github/workflows/runner-cuda-dtype.yml | 4 +- README.md | 19 ++++---- install/install_requirements.sh | 6 +-- runner/run.cpp | 19 ++------ torchchat/cli/builder.py | 65 ++++++++++++++++++++----- torchchat/cli/cli.py | 12 +++++ torchchat/export.py | 51 ++++++++++++++----- torchchat/generate.py | 7 ++- torchchat/usages/eval.py | 4 +- torchchat/utils/build_utils.py | 20 +++++--- 12 files changed, 166 insertions(+), 85 deletions(-) diff --git a/.ci/scripts/validate.sh b/.ci/scripts/validate.sh index 1f7e889d3..ace9ef18d 100644 --- a/.ci/scripts/validate.sh +++ b/.ci/scripts/validate.sh @@ -133,51 +133,51 @@ function generate_aoti_model_output() { echo "******************************************" echo "************** non-quantized *************" echo "******************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path "${MODEL_DIR}/${MODEL_NAME}.so" --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path "$MODEL_DIR/${MODEL_NAME}.so" --prompt "$PROMPT" --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path "${MODEL_DIR}/${MODEL_NAME}.pt2" --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --aoti-package-path "$MODEL_DIR/${MODEL_NAME}.pt2" --prompt "$PROMPT" --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" echo "******************************************" echo "******* Emb: channel-wise quantized ******" echo "******************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" echo "******************************************" echo "******** Emb: group-wise quantized *******" echo "******************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" echo "***********************************************" echo "******* Emb: 4bit channel-wise quantized ******" echo "***********************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 0, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 0, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" echo "***********************************************" echo "******** Emb: 4bit group-wise quantized *******" echo "***********************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 8, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 8, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" if [ "${EXCLUDE_INT8_QUANT:-false}" == false ]; then echo "******************************************" echo "******* INT8 channel-wise quantized ******" echo "******************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" echo "******************************************" echo "******** INT8 group-wise quantized *******" echo "******************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" fi echo "******************************************" @@ -185,8 +185,8 @@ function generate_aoti_model_output() { echo "******************************************" if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then # For CUDA, only bfloat16 makes sense for int4 mm kernel - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" fi done @@ -285,8 +285,8 @@ function eval_model_sanity_check() { echo "******** INT4 group-wise quantized (AOTI) *******" echo "*************************************************" if [ "$DTYPE" != "float16" ]; then - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --dynamic-shapes --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py eval --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --dynamic-shapes --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py eval --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1 cat "$MODEL_DIR/output_eval_aoti" fi; fi; diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 55fe8f11d..dbbde4c68 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -378,8 +378,8 @@ jobs: echo "::group::Run inference with quantize file" if [ $(uname -s) == Darwin ]; then - python3 torchchat.py export --output-dso-path /tmp/model.so --quantize torchchat/quant_config/cuda.json --checkpoint "./checkpoints/${REPO_NAME}/model.pth" - python3 torchchat.py generate --dso-path /tmp/model.so --checkpoint "./checkpoints/${REPO_NAME}/model.pth"~ + python3 torchchat.py export --output-aoti-package-path /tmp/model.pt2 --quantize torchchat/quant_config/cuda.json --checkpoint "./checkpoints/${REPO_NAME}/model.pth" + python3 torchchat.py generate --aoti-package-path /tmp/model.pt2 --checkpoint "./checkpoints/${REPO_NAME}/model.pth"~ fi echo "::endgroup::" @@ -1023,8 +1023,8 @@ jobs: for dtype in fp32 fp16 bf16 fast fast16; do echo "Running export + runner with dtype=$dtype" - python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --dtype $dtype --output-dso-path /tmp/model.so - ./cmake-out/aoti_run /tmp/model.so -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}" + python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --dtype $dtype --output-aoti-package-path /tmp/model.pt2 + ./cmake-out/aoti_run /tmp/model.pt2 -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}" done echo "Tests complete." diff --git a/.github/workflows/runner-cuda-dtype.yml b/.github/workflows/runner-cuda-dtype.yml index a79c262c3..ba0c17766 100644 --- a/.github/workflows/runner-cuda-dtype.yml +++ b/.github/workflows/runner-cuda-dtype.yml @@ -56,9 +56,9 @@ jobs: for DTYPE in bfloat16; do python torchchat.py generate --dtype ${DTYPE} --checkpoint-path ${MODEL_DIR}/stories15M.pt --temperature 0 --prompt "${PROMPT}" --device cuda - python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --output-dso-path /tmp/model.so + python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --output-pt2-path /tmp/model.pt2 - ./cmake-out/aoti_run /tmp/model.so -d CUDA -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}" + ./cmake-out/aoti_run /tmp/model.pt2 -d CUDA -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}" done diff --git a/README.md b/README.md index 251cb7fdc..714378d57 100644 --- a/README.md +++ b/README.md @@ -292,13 +292,14 @@ Use the "Max Response Tokens" slider to limit the maximum number of tokens gener ## Desktop/Server Execution ### AOTI (AOT Inductor) -[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a [DSO](https://en.wikipedia.org/wiki/Shared_library) model (represented by a file with extension `.so`) +[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a zipped PT2 file containing all the artifacts generated by AOTInductor, and a [.so](https://en.wikipedia.org/wiki/Shared_library) file with the runnable contents that is then loaded for inference. This can be done with both Python and C++ enviroments. The following example exports and executes the Llama3.1 8B Instruct model. The first command compiles and performs the actual export. -``` -python3 torchchat.py export llama3.1 --output-dso-path exportedModels/llama3.1.so + +```bash +python3 torchchat.py export llama3.1 --output-aoti-package-path exportedModels/llama3_1_artifacts.pt2 ``` > [!NOTE] @@ -310,12 +311,11 @@ case visit our [customization guide](docs/model_customization.md). ### Run in a Python Enviroment -To run in a python enviroment, use the generate subcommand like before, but include the dso file. +To run in a python enviroment, use the generate subcommand like before, but include the pt2 file. +```bash +python3 torchchat.py generate llama3.1 --aoti-package-path exportedModels/llama3_1_artifacts.pt2 --prompt "Hello my name is" ``` -python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is" -``` -**Note:** Depending on which accelerator is used to generate the .dso file, the command may need the device specified: `--device (cuda | cpu)`. ### Run using our C++ Runner @@ -325,11 +325,10 @@ To run in a C++ enviroment, we need to build the runner binary. torchchat/utils/scripts/build_native.sh aoti ``` -Then run the compiled executable, with the exported DSO from earlier. +Then run the compiled executable, with the pt2. ```bash -cmake-out/aoti_run exportedModels/llama3.1.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time" +cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time" ``` -**Note:** Depending on which accelerator is used to generate the .dso file, the runner may need the device specified: `-d (CUDA | CPU)`. ## Mobile Execution diff --git a/install/install_requirements.sh b/install/install_requirements.sh index cd6c302c2..2ac730ce1 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -47,10 +47,10 @@ fi # NOTE: If a newly-fetched version of the executorch repo changes the value of # PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. -PYTORCH_NIGHTLY_VERSION=dev20240901 +PYTORCH_NIGHTLY_VERSION=dev20240913 # Nightly version for torchvision -VISION_NIGHTLY_VERSION=dev20240901 +VISION_NIGHTLY_VERSION=dev20240913 # Nightly version for torchtune TUNE_NIGHTLY_VERSION=dev20240928 @@ -73,7 +73,7 @@ fi # pip packages needed by exir. REQUIREMENTS_TO_INSTALL=( - torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}" + torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}" torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}" torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}" ) diff --git a/runner/run.cpp b/runner/run.cpp index e161c029e..abfbb4584 100644 --- a/runner/run.cpp +++ b/runner/run.cpp @@ -31,10 +31,7 @@ LICENSE file in the root directory of this source tree. #endif #ifdef __AOTI_MODEL__ -#include -#ifdef USE_CUDA -#include -#endif +#include torch::Device aoti_device(torch::kCPU); #else // __ET_MODEL__ @@ -94,7 +91,7 @@ typedef struct { RunState state; // buffers for the "wave" of activations in the forward pass #ifdef __AOTI_MODEL__ - torch::inductor::AOTIModelContainerRunner* runner; + torch::inductor::AOTIModelPackageLoader* runner; #else // __ET_MODEL__ Module* runner; #endif @@ -144,16 +141,8 @@ void build_transformer( malloc_run_state(&t->state, &t->config); #ifdef __AOTI_MODEL__ -#ifdef USE_CUDA - if (aoti_device.type() == torch::kCUDA) { - t->runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path); - aoti_device = torch::Device(torch::kCUDA); - } else { -#else - { -#endif - t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(model_path); - } + t->runner = new torch::inductor::AOTIModelPackageLoader(model_path); + aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA); #else //__ET_MODEL__ t->runner = new Module( /* path to PTE model */ model_path, diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 3abed339a..b1663ee33 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -51,6 +51,7 @@ class BuilderArgs: gguf_path: Optional[Union[Path, str]] = None gguf_kwargs: Optional[Dict[str, Any]] = None dso_path: Optional[Union[Path, str]] = None + aoti_package_path: Optional[Union[Path, str]] = None pte_path: Optional[Union[Path, str]] = None device: Optional[str] = None precision: torch.dtype = torch.float32 @@ -70,28 +71,29 @@ def __post_init__(self): or (self.checkpoint_dir and self.checkpoint_dir.is_dir()) or (self.gguf_path and self.gguf_path.is_file()) or (self.dso_path and Path(self.dso_path).is_file()) + or (self.aoti_package_path and Path(self.aoti_package_path).is_file()) or (self.pte_path and Path(self.pte_path).is_file()) ): raise RuntimeError( "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path" ) - if self.dso_path and self.pte_path: - raise RuntimeError("specify either DSO path or PTE path, but not both") + if self.aoti_package_path and self.pte_path: + raise RuntimeError("specify either AOTI Package path or PTE path, but not more than one") - if self.checkpoint_path and (self.dso_path or self.pte_path): + if self.checkpoint_path and (self.aoti_package_path or self.pte_path): print( - "Warning: checkpoint path ignored because an exported DSO or PTE path specified" + "Warning: checkpoint path ignored because an exported AOTI or PTE path specified" ) - if self.checkpoint_dir and (self.dso_path or self.pte_path): + if self.checkpoint_dir and (self.aoti_package_path or self.pte_path): print( - "Warning: checkpoint dir ignored because an exported DSO or PTE path specified" + "Warning: checkpoint dir ignored because an exported AOTI or PTE path specified" ) - if self.gguf_path and (self.dso_path or self.pte_path): + if self.gguf_path and (self.aoti_package_path or self.pte_path): print( - "Warning: GGUF path ignored because an exported DSO or PTE path specified" + "Warning: GGUF path ignored because an exported AOTI or PTE path specified" ) - if not (self.dso_path) and not (self.pte_path): + if not (self.aoti_package_path) and not (self.pte_path): self.prefill_possible = True @classmethod @@ -121,6 +123,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": dso_path = getattr(args, "dso_path", None) pte_path = getattr(args, "pte_path", None) + aoti_package_path = getattr(args, "aoti_package_path", None) is_chat_model = False if args.is_chat_model: @@ -131,6 +134,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": checkpoint_dir, dso_path, pte_path, + aoti_package_path, args.gguf_path, ]: if path is not None: @@ -145,6 +149,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": is_chat_model = True output_pte_path = getattr(args, "output_pte_path", None) + output_aoti_package_path = getattr(args, "output_aoti_package_path", None) output_dso_path = getattr(args, "output_dso_path", None) if output_pte_path and args.dtype.startswith("fast"): if args.dtype == "fast": @@ -166,10 +171,11 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": gguf_path=args.gguf_path, gguf_kwargs=None, dso_path=dso_path, + aoti_package_path=aoti_package_path, pte_path=pte_path, device=args.device, precision=dtype, - setup_caches=(output_dso_path or output_pte_path), + setup_caches=(output_dso_path or output_pte_path or output_aoti_package_path), use_distributed=args.distributed, is_chat_model=is_chat_model, dynamic_shapes=getattr(args, "dynamic_shapes", False), @@ -184,6 +190,7 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs": speculative_builder_args.checkpoint_path = args.draft_checkpoint_path speculative_builder_args.gguf_path = None speculative_builder_args.dso_path = None + speculative_builder_args.aoti_package_path = None speculative_builder_args.pte_path = None return speculative_builder_args @@ -494,11 +501,12 @@ def _initialize_model( ) -> Model: print("Loading model...") - if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path): + if builder_args.gguf_path and (builder_args.dso_path or builder_args.aoti_package_path or builder_args.pte_path): print("Setting gguf_kwargs for generate.") is_dso = builder_args.dso_path is not None + is_aoti_package = builder_args.aoti_package_path is not None is_pte = builder_args.pte_path is not None - assert not (is_dso and is_pte) + assert not (is_dso and is_aoti_package and is_pte) assert builder_args.gguf_kwargs is None # TODO: make GGUF load independent of backend # currently not working because AVX int_mm broken @@ -532,6 +540,39 @@ def _initialize_model( ) except: raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}") + + elif builder_args.aoti_package_path: + if not is_cuda_or_cpu_device(builder_args.device): + print( + f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead" + ) + builder_args.device = "cpu" + + # assert ( + # quantize is None or quantize == "{ }" + # ), "quantize not valid for exported PT2 model. Specify quantization during export." + + with measure_time("Time to load model: {time:.02f} seconds"): + model = _load_model(builder_args) + device_sync(device=builder_args.device) + + try: + # Replace model forward with the AOT-compiled forward + # This is a hacky way to quickly demo AOTI's capability. + # model is still a Python object, and any mutation to its + # attributes will NOT be seen on by AOTI-compiled forward + # function, e.g. calling model.setup_cache will NOT touch + # AOTI compiled and maintained model buffers such as kv_cache. + from torch._inductor.package import load_package + aoti_compiled_model = load_package( + str(builder_args.aoti_package_path.absolute()) + ) + model.forward = aoti_compiled_model + metadata = aoti_compiled_model.get_metadata() + builder_args.device = metadata["AOTI_DEVICE_KEY"] + except: + raise RuntimeError(f"Failed to load AOTI compiled {builder_args.aoti_package_path}") + elif builder_args.pte_path: if not is_cpu_device(builder_args.device): print( diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 1d624c6c4..45743ddad 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -191,6 +191,12 @@ def _add_export_output_path_args(parser) -> None: default=None, help="Output to the specified AOT Inductor .dso model file", ) + exclusive_parser.add_argument( + "--output-aoti-package-path", + type=str, + default=None, + help="Output directory for AOTInductor compiled artifacts", + ) def _add_export_args(parser) -> None: @@ -220,6 +226,12 @@ def _add_exported_input_path_args(parser) -> None: default=None, help="Use the specified AOT Inductor .dso model file", ) + exclusive_parser.add_argument( + "--aoti-package-path", + type=Path, + default=None, + help="Use the specified directory containing AOT Inductor compiled files", + ) exclusive_parser.add_argument( "--pte-path", type=Path, diff --git a/torchchat/export.py b/torchchat/export.py index c024e9deb..4e6c22e67 100644 --- a/torchchat/export.py +++ b/torchchat/export.py @@ -11,6 +11,7 @@ import torch.nn as nn from torch.export import Dim +import torch._inductor from torchchat.cli.builder import ( _initialize_model, @@ -35,8 +36,9 @@ def export_for_server( model: nn.Module, device: Optional[str] = "cpu", - output_path: str = "model.dso", + output_path: str = "model.pt2", dynamic_shapes: bool = False, + package: bool = True, ) -> str: """ Export the model using AOT Compile to get a .dso for server use cases. @@ -49,7 +51,7 @@ def export_for_server( The path to the exported model. """ if dynamic_shapes: - input = ( + example_inputs = ( torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device), torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), ) @@ -58,21 +60,31 @@ def export_for_server( # Specify that the first dimension of each input is that batch size dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}} else: - input = ( + example_inputs = ( torch.tensor([[1]], dtype=torch.int, device=device), torch.tensor([0], dtype=torch.int, device=device), ) dynamic_shapes = None with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): - so = torch._export.aot_compile( + metadata = {} # TODO: put more metadata here + options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata} + if not package: + options = {"aot_inductor.output_path": output_path} + + path = torch._export.aot_compile( model, - args=input, - options={"aot_inductor.output_path": output_path}, + example_inputs, dynamic_shapes=dynamic_shapes, + options=options, ) - print(f"The generated DSO model can be found at: {so}") - return so + + if package: + from torch._inductor.package import package_aoti + path = package_aoti(output_path, path) + + print(f"The generated packaged model can be found at: {path}") + return path """ @@ -338,14 +350,16 @@ def main(args): print(f"Using device={builder_args.device}") set_precision(builder_args.precision) - set_backend(dso=args.output_dso_path, pte=args.output_pte_path) + set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path) builder_args.dso_path = None builder_args.pte_path = None + builder_args.aoti_package_path = None builder_args.setup_caches = True output_pte_path = args.output_pte_path output_dso_path = args.output_dso_path + output_aoti_package_path = args.output_aoti_package_path if output_pte_path and builder_args.device != "cpu": print( @@ -387,6 +401,7 @@ def main(args): ) model_to_pte = model model_to_dso = model + model_to_aoti_package = model else: if output_pte_path: _set_gguf_kwargs(builder_args, is_et=True, context="export") @@ -396,13 +411,14 @@ def main(args): ) _unset_gguf_kwargs(builder_args) - if output_dso_path: + if output_dso_path or output_aoti_package_path: _set_gguf_kwargs(builder_args, is_et=False, context="export") - model_to_dso = _initialize_model( + model_to_aoti_package = _initialize_model( builder_args, quantize, support_tensor_subclass=False, ) + model_to_dso = model_to_aoti_package _unset_gguf_kwargs(builder_args) with torch.no_grad(): @@ -419,9 +435,22 @@ def main(args): if output_dso_path: output_dso_path = str(os.path.abspath(output_dso_path)) print(f"Exporting model using AOT Inductor to {output_dso_path}") + print("WARNING!! The path of compiling a dso is deprecated. Please use aoti_package_path to create a .pt2 artifact instead.") export_for_server( model_to_dso, builder_args.device, output_dso_path, builder_args.dynamic_shapes, + package=False, + ) + + if output_aoti_package_path: + output_aoti_package_path = str(os.path.abspath(output_aoti_package_path)) + print(f"Exporting model using AOT Inductor to {output_aoti_package_path}") + export_for_server( + model_to_aoti_package, + builder_args.device, + output_aoti_package_path, + builder_args.dynamic_shapes, + package=True, ) diff --git a/torchchat/generate.py b/torchchat/generate.py index c38fcaff5..f04f19756 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -155,6 +155,8 @@ def validate_build( reason = "model compilation for prefill" if self.compile: reason = "model compilation" + if builder_args.aoti_package_path: + model_type = "PT2" if builder_args.dso_path: model_type = "DSO" if builder_args.pte_path: @@ -168,7 +170,10 @@ def validate_build( def from_args(cls, args): dso_path = getattr(args, "dso_path", None) pte_path = getattr(args, "pte_path", None) - sequential_prefill = args.sequential_prefill or bool(dso_path) or bool(pte_path) + aoti_package_path = getattr(args, "aoti_package_path", None) + sequential_prefill = ( + args.sequential_prefill or bool(aoti_package_path) or bool(pte_path) or bool(dso_path) + ) # Validate that all image prompts exist before expensive model load if image_prompts := getattr(args, "image_prompts", None): diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index 5993c3781..b708e5840 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -260,7 +260,7 @@ def main(args) -> None: if compile: assert not ( - builder_args.dso_path or builder_args.pte_path + builder_args.dso_path or builder_args.pte_path or builder_args.aoti_package_path ), "cannot compile exported model" model_forward = torch.compile( model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True @@ -288,6 +288,8 @@ def main(args) -> None: ) if builder_args.dso_path: print(f"For model {builder_args.dso_path}") + elif builder_args.aoti_package_path: + print(f"For model {builder_args.aoti_package_path}") elif builder_args.pte_path: print(f"For model {builder_args.pte_path}") elif builder_args.checkpoint_path: diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py index abda72d70..f4292a095 100644 --- a/torchchat/utils/build_utils.py +++ b/torchchat/utils/build_utils.py @@ -69,42 +69,46 @@ def unpack_packed_weights( active_builder_args_dso = None active_builder_args_pte = None +active_builder_args_aoti_package = None -def set_backend(dso, pte): +def set_backend(dso, pte, aoti_package): global active_builder_args_dso global active_builder_args_pte active_builder_args_dso = dso + active_builder_args_aoti_package = aoti_package active_builder_args_pte = pte def use_aoti_backend() -> bool: global active_builder_args_dso + global active_builder_args_aoti_package global active_builder_args_pte # eager == aoti, which is when backend has not been explicitly set - if (not active_builder_args_dso) and not (active_builder_args_pte): + if (not active_builder_args_pte) and (not active_builder_args_aoti_package): return True - if active_builder_args_pte and active_builder_args_dso: + if active_builder_args_pte and active_builder_args_aoti_package: raise RuntimeError( - "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!" + "code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!" ) - return bool(active_builder_args_dso) + return bool(active_builder_args_dso) or bool(active_builder_args_aoti_package) def use_et_backend() -> bool: global active_builder_args_dso + global active_builder_args_aoti_package global active_builder_args_pte # eager == aoti, which is when backend has not been explicitly set - if not (active_builder_args_pte or active_builder_args_dso): + if (not active_builder_args_pte) and (not active_builder_args_aoti_package): return False - if active_builder_args_pte and active_builder_args_dso: + if active_builder_args_pte and active_builder_args_aoti_package: raise RuntimeError( - "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!" + "code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!" ) return bool(active_builder_args_pte)