
Commit

Fix executorch kv cache incompatibility with to_executorch lowering (#…
dvorjackz authored Jan 10, 2025
1 parent 0c4053e commit 9666ee8
Showing 8 changed files with 365 additions and 31 deletions.
10 changes: 8 additions & 2 deletions examples/models/llama/export_llama_lib.py
@@ -23,6 +23,7 @@
import torch

from executorch.devtools.etrecord import generate_etrecord
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

from executorch.extension.llm.export.builder import DType, LLMEdgeManager

@@ -775,6 +776,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

additional_passes = []
if args.model in TORCHTUNE_DEFINED_MODELS:
additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -789,7 +793,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch()
builder = builder.to_executorch(
passes=additional_passes,
)

# Generate ETRecord
if edge_manager_copy:
@@ -807,7 +813,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch()
builder = builder.to_executorch(passes=additional_passes)

if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")
18 changes: 15 additions & 3 deletions examples/models/llama3_2_vision/runner/native.py
@@ -18,13 +18,15 @@
TorchTuneLlamaRunner,
)

from executorch.extension.pybindings.portable_lib import _load_for_executorch
from executorch.extension.pybindings.portable_lib import (
_load_for_executorch_from_buffer,
)

# Load custom ops and quantized ops.
from executorch.extension.pybindings import portable_lib # noqa # usort: skip

# Note: import this after portable_lib
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip
from executorch.kernels import quantized # noqa


@@ -43,7 +45,17 @@ def __init__(self, args):
use_kv_cache=args.kv_cache,
vocab_size=params["vocab_size"],
)
self.model = _load_for_executorch(args.pte)
# Save the loaded model bytes to prevent data from going out of
# scope after the `with` and getting cleaned up by Python's
# garbage collector.
self.model_bytes = None
with open(args.pte, "rb") as f:
self.model_bytes = f.read()
# Need to use _load_for_executorch_from_buffer instead of
# _load_for_executorch because the latter uses MmapDataLoader,
# which doesn't have load_into() implemented, which is needed
# for loading initialized mutable buffers.
self.model = _load_for_executorch_from_buffer(self.model_bytes)
self.use_kv_cache = args.kv_cache

def forward(
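
For reference, a minimal sketch of the buffer-based loading pattern used in this runner; the "model.pte" path is a placeholder, not from this commit:

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

# Read the program into memory and keep a reference to the bytes for as long
# as the loaded module is used, so they are not garbage collected out from
# under it.
with open("model.pte", "rb") as f:  # placeholder path
    model_bytes = f.read()

# _load_for_executorch uses MmapDataLoader, which does not implement
# load_into(); the buffer-based loader does, which is what initialized
# mutable buffers require.
model = _load_for_executorch_from_buffer(model_bytes)
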
9 changes: 7 additions & 2 deletions exir/emit/_emitter.py
@@ -1575,7 +1575,8 @@ def _find_fqn_for_placeholder(
warnings.warn(
"Mutation on a buffer in the model is detected. ExecuTorch assumes "
"buffers that are mutated in the graph have a meaningless initial state, "
"only the shape and dtype will be serialized.",
"only the shape and dtype will be serialized, unless a pass which sets "
'meta["et_init_buffer"] to True such as InitializedMutableBufferPass is run.',
UserWarning,
stacklevel=1,
)
@@ -1602,6 +1603,7 @@ def placeholder(
"""
spec = self.node.meta["spec"]
constant_tag = self.node.meta.get("constant_tag", None)
initialize_buffer = self.node.meta.get("et_init_buffer", None)
is_user_input = True

if isinstance(target, str) and isinstance(spec, TensorSpec):
@@ -1655,7 +1657,10 @@ def placeholder(
spec.storage = real_tensor.untyped_storage()

# User inputs and mutable buffers are not constants, other buffers or parameters are.
spec.const = not (is_user_input or is_mutable_buffer)
if initialize_buffer and is_mutable_buffer:
spec.const = True
else:
spec.const = not (is_user_input or is_mutable_buffer)

evalue = (
self._tensor_spec_to_evalue(spec, constant_tag)
53 changes: 53 additions & 0 deletions exir/emit/test/test_emit.py
@@ -9,6 +9,7 @@
import typing
import unittest
from contextlib import contextmanager
from copy import deepcopy
from typing import List, Optional, Tuple

import executorch.exir as exir
@@ -31,6 +32,7 @@
from executorch.exir.error import InternalError
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.print_program import pretty_print, print_program # noqa
from executorch.exir.schema import (
@@ -56,6 +58,7 @@
from executorch.extension.pybindings.portable_lib import (
_load_for_executorch_from_buffer,
)
from executorch.runtime import Runtime

from functorch.experimental import control_flow
from torch import nn
@@ -243,6 +246,56 @@ def forward(self, x):
)
self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)

def test_initialized_mutable_buffer(self):
"""Test that mutable buffers can hold meaningful initialized state."""

class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
# Mutable buffer with non-empty initial state.
self.register_buffer("cache_pos", torch.arange(0, 10))

def forward(self, x):
self.cache_pos.add_(1)
return self.cache_pos

m = TestModule()
example_inputs = (torch.ones(10),)
ep = torch.export.export(m, example_inputs)
edge = to_edge(
ep,
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
)

# Save a copy of the edge program since to_executorch is
# stateful to some degree.
edge_copy = deepcopy(edge)
et_config = ExecutorchBackendConfig(
passes=[InitializedMutableBufferPass(["cache_pos"])],
)
et_program_init_pass = edge.to_executorch(config=et_config)
et_program_regular = edge_copy.to_executorch()

runtime = Runtime.get()
program_init_pass = runtime.load_program(et_program_init_pass.buffer)
method_init_pass = program_init_pass.load_method("forward")

program_regular = runtime.load_program(et_program_regular.buffer)
method_regular = program_regular.load_method("forward")

# Test that the mutable buffer is initialized with its original values.
self.assertTrue(
    torch.allclose(
        method_init_pass.execute(example_inputs)[0], torch.arange(1, 11)
    )
)
# Test that without the pass the mutable buffer is uninitialized and starts
# from default zeros; we compare against torch.ones because the model's
# forward mutates the buffer with += 1.
self.assertTrue(
    torch.allclose(
        method_regular.execute(example_inputs)[0],
        torch.ones(10, dtype=torch.int64),
    )
)

def test_int_list_input(self):
class M(torch.nn.Module):
def forward(self, x, y, z):
32 changes: 32 additions & 0 deletions exir/passes/init_mutable_pass.py
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

from executorch.exir.pass_base import ExportPass


class InitializedMutableBufferPass(ExportPass):
"""
If a buffer's name contains one of the specified patterns, set
meta["et_init_buffer"] to True, which provides the mutable buffer with an
initialized state. For example, with patterns = ["cache_pos"], a module with
`self.register_buffer("cache_pos", torch.arange(10))` would have its initial
state serialized instead of being left uninitialized by default.
"""

def __init__(self, patterns: List[str]) -> None:
super().__init__()
self.patterns = patterns

def placeholder(self, name: str, arg, meta):
for pattern in self.patterns:
if pattern in name:
meta["et_init_buffer"] = True

return super().placeholder(name, arg, meta)
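
A usage sketch for the new pass (mirroring the test added in exir/emit/test/test_emit.py; the toy module below is illustrative, not part of this commit):

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass


class CacheToy(torch.nn.Module):  # illustrative module
    def __init__(self):
        super().__init__()
        # Mutable buffer whose initial contents should be serialized.
        self.register_buffer("cache_pos", torch.arange(0, 10))

    def forward(self, x):
        self.cache_pos.add_(1)
        return self.cache_pos


ep = torch.export.export(CacheToy(), (torch.ones(10),))
edge = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))
et_program = edge.to_executorch(
    config=ExecutorchBackendConfig(
        passes=[InitializedMutableBufferPass(["cache_pos"])],
    )
)
# Without the pass, `cache_pos` would start from zeros at runtime; with it,
# the serialized program keeps the torch.arange(0, 10) initial state.

Supplying the pass through ExecutorchBackendConfig is the same mechanism used by the attention test and, via the new `passes` argument on LLMEdgeManager.to_executorch, by the llama export flow in this commit.
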
23 changes: 15 additions & 8 deletions extension/llm/export/builder.py
@@ -27,6 +27,7 @@
from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig

from executorch.exir.pass_manager import PassType
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
@@ -415,21 +416,27 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag

return self

def to_executorch(self) -> "LLMEdgeManager":
def to_executorch(
self, passes: Optional[List[PassType]] = None
) -> "LLMEdgeManager":
"""
Lower the model to executorch and get an ExecutorchProgram.
"""
assert self.edge_manager, "Need to run export_to_edge() first"
to_executorch_passes = [
# If there are Linear operations left in the graph, let's execute
# them with the optimized op_linear rather than materializing a
# transpose followed by a regular op_mm.
ConvertToLinearPass(),
QuantFusionPass(),
]
if passes:
to_executorch_passes.extend(passes)

self.export_program = self.edge_manager.to_executorch(
ExecutorchBackendConfig(
extract_delegate_segments=True,
passes=[
# If there are Linear operations left in the graph, let's execute
# them with the optimized op_linear rather than materializing a
# transpose followed by a regular op_mm.
ConvertToLinearPass(),
QuantFusionPass(),
],
passes=to_executorch_passes,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)
34 changes: 18 additions & 16 deletions extension/llm/modules/test/test_attention.py
@@ -11,6 +11,8 @@
import torch
from executorch.exir import EdgeCompileConfig, to_edge

from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
from executorch.extension.llm.modules.attention import (
MultiHeadAttention as ETMultiHeadAttention,
)
@@ -114,7 +116,7 @@ def test_attention_eager(self):
et_res = self.et_mha(self.x, self.x) # Self attention.
tt_res = self.tt_mha(self.x, self.x) # Self attention.

self.assertTrue(torch.allclose(et_res, tt_res))
assert_close(et_res, tt_res)
self.et_mha.reset_cache()
self.tt_mha.reset_cache()

@@ -125,7 +127,7 @@ def test_attention_eager(self):
self.x, self.x, input_pos=self.input_pos
) # Self attention with input pos.

self.assertTrue(torch.allclose(et_res, tt_res))
assert_close(et_res, tt_res)

# test kv cache read. Input pos can be [10, 11, ..., 19]
next_input_pos = torch.arange(10, 20).unsqueeze(0)
@@ -187,9 +189,8 @@ def test_attention_aoti(self):

def test_attention_executorch(self):
# Self attention.
# TODO: Fix kv cache
# self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
# self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)

with torch.no_grad():
et_mha_ep = torch.export.export(
@@ -202,9 +203,15 @@ def test_attention_executorch(self):
et_program = to_edge(
et_mha_ep,
compile_config=EdgeCompileConfig(
_core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
_core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
_check_ir_validity=False,
),
).to_executorch()
).to_executorch(
config=ExecutorchBackendConfig(
passes=[InitializedMutableBufferPass(["cache_pos"])],
)
)

runtime = Runtime.get()
program = runtime.load_program(et_program.buffer)
method = program.load_method("forward")
@@ -219,28 +226,23 @@ def test_attention_torch_cond_eager(self):
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

# mask
mask = self.causal_mask[self.input_pos, :]
# First run
# First run.
et_res = self.et_mha(
self.x, self.x, mask=mask, input_pos=self.input_pos
) # Self attention with input pos.
tt_res = self.tt_mha(
self.x, self.x, mask=mask, input_pos=self.input_pos
) # Self attention with input pos.

self.assertTrue(torch.allclose(et_res, tt_res))
assert_close(et_res, tt_res)

# Second run test kv cache read. Input pos is [10, 11, ..., 19]
next_input_pos = torch.arange(10, 20).unsqueeze(0)

empty_y = torch.full_like(self.x, torch.nan)
mask = self.causal_mask[next_input_pos, :]
et_res = self.et_mha(
self.x, empty_y, mask=mask, input_pos=next_input_pos
) # Self attention with input pos.
tt_res = self.tt_mha(
self.x, None, mask=mask, input_pos=next_input_pos
) # Self attention with input pos.
et_res = self.et_mha(self.x, empty_y, mask=mask, input_pos=next_input_pos)
tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

assert_close(et_res, tt_res)