[aoti] Remove need for -l in cmake call #1159
base: main

Changes from all commits
```diff
@@ -5,13 +5,13 @@
 # LICENSE file in the root directory of this source tree.

 import os
-from typing import Optional
+from typing import Dict, Optional

 import torch
+import torch._inductor
 import torch.nn as nn

 from torch.export import Dim
-import torch._inductor

 from torchchat.cli.builder import (
     _initialize_model,
```
```diff
@@ -39,6 +39,7 @@ def export_for_server(
     output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
     package: bool = True,
+    metadata: Optional[Dict[str, str]] = None,
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.
```
```diff
@@ -67,8 +68,10 @@ def export_for_server(
         dynamic_shapes = None

     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        metadata = {}  # TODO: put more metadata here
-        options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
+        options = {
+            "aot_inductor.package": package,
+            "aot_inductor.metadata": metadata or {},
+        }
         if not package:
             options = {"aot_inductor.output_path": output_path}
```
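For context, a minimal sketch of what these `aot_inductor.*` options feed into, using a toy module. This is hedged: the exact behavior and return type of `torch._inductor.aot_compile` vary across PyTorch releases, and torchchat exports its full transformer rather than this stand-in.

```python
import torch

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x * 2

example = (torch.randn(4),)
ep = torch.export.export(Tiny(), example)
options = {
    "aot_inductor.package": True,
    "aot_inductor.metadata": {"tokenizer_type": "3"},
}
# Compiles to AOTI artifacts; with packaging enabled this returns the
# set of files that package_aoti later bundles into a single .pt2.
path = torch._inductor.aot_compile(ep.module(), example, options=options)
```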
```diff
@@ -81,6 +84,7 @@ def export_for_server(
         if package:
+            from torch._inductor.package import package_aoti

             path = package_aoti(output_path, path)

     print(f"The generated packaged model can be found at: {path}")
```

Review comment: Sorry, somehow my editor added on a bunch of formatting changes here.. hope it's not too confusing, otherwise I can try to remove them.

Reply: No worries, these are good lint fixes.
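Following up on the packaging step above: the resulting .pt2 can be loaded back as a single artifact, which is presumably what lets the runner avoid the tokenizer-specific `-l` flag the PR title mentions. A hedged sketch, assuming the `aoti_load_package` helper present in recent PyTorch builds:

```python
import torch

# aoti_load_package ships in recent PyTorch releases; older builds
# expose an equivalent loader under torch._inductor.package instead.
runner = torch._inductor.aoti_load_package("model.pt2")
output = runner(torch.randn(4))  # callable roughly like the original module
```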
```diff
@@ -102,13 +106,13 @@ def export_for_server(
     from typing import Any, Dict, Tuple, Union

     import executorch.exir as exir
+    from executorch.backends.xnnpack._passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )

     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
-    from executorch.backends.xnnpack._passes.convert_to_linear import (
-        ConvertToLinearPass,
-    )
     from executorch.exir import EdgeProgramManager, to_edge

     from executorch.exir.capture._config import (
```
```diff
@@ -166,18 +170,22 @@ def __init__(self, attention: Attention):

         self.wo = attention.wo

-        max_batch_size, n_heads, max_seq_length, head_dim = (
-            attention.kv_cache[0].k_cache.shape
-        )
+        max_batch_size, n_heads, max_seq_length, head_dim = attention.kv_cache[
+            0
+        ].k_cache.shape
         cache_dtype = attention.kv_cache[0].k_cache.dtype
         # The `Attention` module being replaced can have multiple KV caches
         # (denoted by `cache_lanes`). Thus we follow the same setup format
         # as in `Attention.setup_cache`.
         cache_lanes = len(attention.kv_cache)
-        self.kv_cache = nn.ModuleList([
-            CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
-            for _ in range(cache_lanes)
-        ])
+        self.kv_cache = nn.ModuleList(
+            [
+                CustomKVCache(
+                    max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
+                )
+                for _ in range(cache_lanes)
+            ]
+        )

         self.n_heads = attention.n_heads
         self.head_dim = attention.head_dim
```
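To make the cache-lane setup concrete, here is an illustrative reduction with a stand-in for `CustomKVCache` (the real class comes from ExecuTorch); the shapes and dtype flow exactly as in the diff above, with one cache built per lane:

```python
import torch
import torch.nn as nn

class ToyKVCache(nn.Module):
    """Stand-in for ExecuTorch's CustomKVCache, for illustration only."""
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, max_seq_length, n_heads, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

# All lanes are sized from the first lane's k_cache, as in the diff:
k = torch.zeros(1, 8, 300, 64)  # toy (batch, heads, seq, head_dim) cache
max_batch_size, n_heads, max_seq_length, head_dim = k.shape
cache_lanes = 2
kv_cache = nn.ModuleList(
    ToyKVCache(max_batch_size, max_seq_length, n_heads, head_dim, k.dtype)
    for _ in range(cache_lanes)
)
print(len(kv_cache))  # 2 independent lanes
```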
```diff
@@ -215,9 +223,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
         return self.wo(output)

 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-    from executorch.extension.llm.custom_ops import (  # noqa
-        sdpa_with_kv_cache,
-    )
+    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa

     for name, child in module.named_children():
         if isinstance(child, Attention):
```
```diff
@@ -238,7 +244,9 @@ def _to_core_aten(
         raise ValueError(
             f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
         )
-    core_aten_ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes)
+    core_aten_ep = export_for_training(
+        model, example_inputs, dynamic_shapes=dynamic_shapes
+    )
     if verbose:
         logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
     return core_aten_ep
```
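As an aside on the call being re-wrapped here, a self-contained sketch of `export_for_training` with dynamic shapes, using a toy module (torchchat passes its real model, inputs, and shape constraints; `export_for_training` is available in recent PyTorch releases):

```python
import torch
from torch.export import Dim, export_for_training

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x + 1

# Mark dim 1 of `x` as dynamic, bounded the way the export path bounds
# sequence length.
seq = Dim("seq", max=300)
core_aten_ep = export_for_training(
    Tiny(), (torch.zeros(2, 8),), dynamic_shapes={"x": {1: seq}}
)
print(core_aten_ep.graph)  # the Core ATen graph logged when verbose
```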
```diff
@@ -350,7 +358,11 @@ def main(args):

     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
-    set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
+    set_backend(
+        dso=args.output_dso_path,
+        pte=args.output_pte_path,
+        aoti_package=args.output_aoti_package_path,
+    )

     builder_args.dso_path = None
     builder_args.pte_path = None
```
```diff
@@ -372,6 +384,7 @@ def main(args):

     # TODO: clean this up
     # This mess is because ET does not support _weight_int4pack_mm right now
+    tokenizer_args = None
     if not builder_args.gguf_path:
         # tokenizer needed for quantization so get that here,
         try:
```
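The one functional line in this hunk is `tokenizer_args = None`: the metadata code added later in this diff reads `tokenizer_args` unconditionally, so the name must be bound even when tokenizer loading is skipped. An illustrative reduction with stand-in names (not torchchat code):

```python
gguf_path = "model.gguf"  # pretend a GGUF model was passed on the CLI

tokenizer_args = None  # the line this hunk adds: always bind the name
if not gguf_path:  # tokenizer loading is skipped for GGUF models
    tokenizer_args = object()  # stand-in for a real TokenizerArgs
# Later code can now test tokenizer_args without risking a NameError:
tokenizer_type = "0" if tokenizer_args is None else "3"
print(tokenizer_type)  # "0"
```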
```diff
@@ -382,9 +395,8 @@ def main(args):

     if builder_args.max_seq_length is None:
-        if (
-            (output_dso_path is not None or output_aoti_package_path is not None)
-            and not builder_args.dynamic_shapes
-        ):
+        if (
+            output_dso_path is not None or output_aoti_package_path is not None
+        ) and not builder_args.dynamic_shapes:
             print("Setting max_seq_length to 300 for DSO export.")
             builder_args.max_seq_length = 300
         elif output_pte_path is not None:
```
```diff
@@ -397,7 +409,8 @@ def main(args):
         quantize,
         tokenizer,
         max_seq_length=builder_args.max_seq_length,
-        support_tensor_subclass=output_dso_path is None and output_aoti_package_path is None,
+        support_tensor_subclass=output_dso_path is None
+        and output_aoti_package_path is None,
     )
     model_to_pte = model
     model_to_dso = model
```
```diff
@@ -435,7 +448,9 @@ def main(args):
     if output_dso_path:
         output_dso_path = str(os.path.abspath(output_dso_path))
         print(f"Exporting model using AOT Inductor to {output_dso_path}")
-        print("WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead.")
+        print(
+            "WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
+        )
         export_for_server(
             model_to_dso,
             builder_args.device,
```
```diff
@@ -446,11 +461,23 @@ def main(args):

     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
-        print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
+
+        if tokenizer_args is None:
+            tokenizer_type = "0"
+        elif tokenizer_args.is_sentencepiece:
+            tokenizer_type = "2"  # Corresponding to llama2
+        else:
+            tokenizer_type = "3"  # Corresponding to llama3
+
+        metadata = {"tokenizer_type": tokenizer_type}
+        print(
+            "Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
+        )
         export_for_server(
             model_to_aoti_package,
             builder_args.device,
             output_aoti_package_path,
             builder_args.dynamic_shapes,
             package=True,
+            metadata=metadata,
         )
```
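The tokenizer-type convention introduced here ("0" for unknown, "2" for SentencePiece/llama2, "3" for llama3) travels inside the .pt2 metadata, presumably so the runner can pick the right tokenizer without a tokenizer-specific `-l` flag at cmake time. The same branching, factored into a hypothetical helper for clarity (the PR inlines this logic):

```python
def tokenizer_type_for(tokenizer_args) -> str:
    """Mirror of the branching added above; illustrative only."""
    if tokenizer_args is None:
        return "0"  # no tokenizer information available
    if tokenizer_args.is_sentencepiece:
        return "2"  # corresponding to llama2
    return "3"  # corresponding to llama3

# metadata = {"tokenizer_type": tokenizer_type_for(tokenizer_args)}
```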
Review comment: We can delete `metadata = metadata or {}`.
Reply: General advice is to not have mutable structures as default args, because they survive across invocations: https://docs.python-guide.org/writing/gotchas/
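The gotcha being referenced, in miniature:

```python
def append_to(item, bucket=[]):  # one shared list, created at def time
    bucket.append(item)
    return bucket

print(append_to(1))  # [1]
print(append_to(2))  # [1, 2] -- the default survived the first call

# Hence `metadata: Optional[Dict[str, str]] = None` plus `metadata or {}`
# in this PR, rather than a `metadata: Dict[str, str] = {}` default.
```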