diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 84b174bba..93f60f8f8 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -14,15 +14,12 @@ from sharktank.layers import * from sharktank.types import * -# TODO: Should be using a base class with the protocol supported. from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 -from ..models.llama.sharding import shard_theta from ..models.mixtral.mixtral import * from ..models.grok.grok import * -from .. import ops -def main(): +def main(raw_args: list[str] | None = None): from ..utils import cli parser = cli.create_parser() @@ -60,7 +57,7 @@ def main(): choices=["decomposed", "torch"], ) - args = cli.parse(parser) + args = cli.parse(parser, args=raw_args) dataset_type = cli.get_input_data_files(args) dataset_type = "irpa" if "irpa" in dataset_type else "gguf" dataset = cli.get_input_dataset(args) @@ -110,7 +107,7 @@ def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]): fxb = FxProgramsBuilder(model) - def setup_cache(model, shard_count): + def setup_cache(model): if model.config.kv_cache_type == "paged": cache_state = model.cache.allocate( page_count=hp.context_length // llama_config.block_seq_stride @@ -161,7 +158,7 @@ def generate_batch_prefill(bs: int): sl_dim = llama_config.block_seq_stride * block_dim cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache( - model, llama_config.tensor_parallelism_size + model ) # We need to offset the indices for the cache @@ -234,7 +231,7 @@ def generate_batch_decode(bs: int): cache_shard_dim, cache_dynamic_shapes, arg_affinities, - ) = setup_cache(model, llama_config.tensor_parallelism_size) + ) = setup_cache(model) # We need to offset the indices for the cache arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities} diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index d7ade43a7..fa6ba587b 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -300,7 +300,7 @@ def shard_state( """Shard an unsharded state. We can't just split the slab on the sub page dims. First it needs to be reinterpreted into the actual shape. - The split the head dimension, then flatten each shard. + Then split the head dimension, then flatten each shard. 
This is a work-around for the lack of block-cyclic sharded tensor type.""" if self.shard_count == 1: return state @@ -324,6 +324,9 @@ def shard_state( flat_sharded_page_table = SplitPrimitiveTensor(ts=shards, shard_dim=1) return [flat_sharded_page_table] + def unshard_state(self, state: list[SplitPrimitiveTensor]) -> list[torch.Tensor]: + return [ops.unshard(self.unflatten_page_table(state)).flatten(start_dim=1)] + @property def pad_sequence_stride(self) -> int: return self.block_seq_stride diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 656b4432b..d1cacefd1 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -186,29 +186,6 @@ def decode( self._assert_device(start_positions) self._assert_device(*cache_state, dtype=self.activation_dtype) - if self.config.tensor_parallelism_size > 1: - if not isinstance(tokens, ReplicatedTensor): - tokens = ops.replicate( - tokens, count=self.config.tensor_parallelism_size - ) - if not isinstance(attention_mask, ReplicatedTensor): - attention_mask = ops.replicate( - attention_mask, count=self.config.tensor_parallelism_size - ) - if not isinstance(start_positions, ReplicatedTensor): - start_positions = ops.replicate( - start_positions, count=self.config.tensor_parallelism_size - ) - if not isinstance(seq_block_ids, ReplicatedTensor): - seq_block_ids = ops.replicate( - seq_block_ids, count=self.config.tensor_parallelism_size - ) - # If the user provided unsharded arguments they probably want - # an unsharded result as well. - unshard_result = True - else: - unshard_result = False - bs, _ = tokens.shape # Precompute a position based mask for computing rope embeddings # as it is the same for all blocks. diff --git a/sharktank/sharktank/utils/testing.py b/sharktank/sharktank/utils/testing.py index 7b91b3a13..feebb25cd 100644 --- a/sharktank/sharktank/utils/testing.py +++ b/sharktank/sharktank/utils/testing.py @@ -14,9 +14,13 @@ from typing import Any, Callable from operator import eq from collections.abc import Iterable +import pytest +from sharktank.utils.tokenizer import InferenceTokenizer from ..types import * +longrun = pytest.mark.skipif("not config.getoption('longrun')") + # Range of torch.rand() is [0,1) # Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values def make_rand_torch(shape, dtype=torch.float32): @@ -31,6 +35,16 @@ def tearDown(self): shutil.rmtree(self._temp_dir, ignore_errors=True) +@pytest.mark.usefixtures("path_prefix") +class PathPrefixTestBase(TempDirTestBase): + """Creates a temporary directory and uses it if a path prefix is not given.""" + + def setUp(self): + super().setUp() + if self.path_prefix is None: + self.path_prefix = f"{self._temp_dir}/" + + class MainRunnerTestBase(TempDirTestBase): """Performs an in-process test of a `main(args)` func.""" @@ -54,6 +68,25 @@ def assertFileWritten(self, p: Path): self.assertGreater(p.stat().st_size, 0, msg=f"Expected file {p} had zero size") +class ModuloTokenizer(InferenceTokenizer): + """A tokenizer used for testing where we take a modulo of each character. 
+ Guarantees that we are producing tokens of up to the max token ID.""" + + def __init__(self, vocabulary_size: int): + self.vocabulary_size = vocabulary_size + + def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]: + return [ + [ord(character) % self.vocabulary_size for character in text] + for text in texts + ] + + def _decode(self, tokens: list[list[int]]) -> list[str]: + return [ + "".join([chr(token) for token in prompt_tokens]) for prompt_tokens in tokens + ] + + @contextlib.contextmanager def temporary_directory(identifier: str): """Returns a context manager TemporaryDirectory suitable for testing. diff --git a/sharktank/sharktank/utils/tokenizer.py b/sharktank/sharktank/utils/tokenizer.py index b459c706a..597533373 100644 --- a/sharktank/sharktank/utils/tokenizer.py +++ b/sharktank/sharktank/utils/tokenizer.py @@ -75,7 +75,7 @@ def pad_tokens( return token_ids, lengths @abstractmethod - def _encode(self, texts: list[str]) -> list[list[int]]: + def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]: ... @abstractmethod diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index 386061731..ee50d3d3a 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -4,42 +4,282 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import unittest +from copy import deepcopy +from iree.compiler import compile_file, InputType +from typing import Any +import functools +import os import pytest -from typing import Any, List, Tuple, OrderedDict +import torch + +from sharktank.examples import export_paged_llm_v1 +from sharktank.examples.sharding import shard_llm_dataset +from sharktank.examples.paged_llm_v1 import TorchGenerator from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 -import sharktank.ops as ops -from sharktank.types import unbox_tensor, Dataset, UnreducedTensor, SplitPrimitiveTensor -from sharktank.models.llama.testing import make_random_llama_theta -from sharktank.utils.testing import skip -from sharktank.models.llama.sharding import shard_theta +from sharktank.layers import CausalLMModelABC from sharktank.layers.configs import LlamaHParams +from sharktank.layers.testing import CausalLMIreeModel +from sharktank.models.llama.sharding import shard_theta +from sharktank.models.llama.testing import make_random_llama_theta +from sharktank.types import ( + AnyTensor, + InferenceTensor, + DefaultPrimitiveTensor, + Dataset, + dtype_to_serialized_name, +) from sharktank.utils.math import round_up_to_multiple_of -from sharktank.utils import iterables_equal from sharktank.utils.iree import ( get_iree_devices, load_iree_module, - run_iree_module_function, - prepare_iree_module_function_args, - call_torch_module_function, - iree_to_torch, ) -from sharktank.export import export as sharktank_export -import tempfile -import torch -from copy import deepcopy -from iree.turbine.aot import FxProgramsBuilder, export -import iree.runtime -import numpy as np -import os +from sharktank.utils.testing import PathPrefixTestBase, ModuloTokenizer, longrun +from sharktank.utils.tokenizer import load_tokenizer, InferenceTokenizer +import sharktank.ops as ops + + +AnyTokenizer = Any + + +def set_float_dtype(tensor: InferenceTensor, dtype: torch.dtype) -> InferenceTensor: + if isinstance(tensor, DefaultPrimitiveTensor) and 
tensor.dtype.is_floating_point: + return DefaultPrimitiveTensor( + name=tensor.name, data=ops.to(tensor, dtype=dtype) + ) + assert False, "Unsupported tensor type" + + +def shard_dataset( + path: str, + output_path: str, + tensor_parallelism_size: int, + intermediates_caching: bool, +): + if not intermediates_caching or not os.path.exists(output_path): + if path.endswith(".gguf"): + dataset_arg = f"--gguf-file={path}" + elif path.endswith(".irpa"): + dataset_arg = f"--irpa-file={path}" + else: + raise ValueError(f'Invalid dataset filename "{dataset_arg}"') + shard_llm_dataset.main( + [ + f"--tensor-parallelism-size={tensor_parallelism_size}", + dataset_arg, + f"--output-irpa-file={output_path}", + ] + ) + + +def compile_iree_module( + intermediates_caching: bool, + config: LlamaModelConfig, + dataset_path: str, + batch_size: int, + target_device: str, + output_mlir_path: str, + output_module_path: str, + output_config_path: str, +): + if not intermediates_caching or not os.path.exists(output_module_path): + export_paged_llm_v1.main( + [ + f"--output-mlir={output_mlir_path}", + f"--irpa-file={dataset_path}", + f"--output-config={output_config_path}", + f"--bs={batch_size}", + f"--block-seq-stride={config.block_seq_stride}", + f"--attention-dtype={dtype_to_serialized_name(config.attention_dtype)}", + f"--activation-dtype={dtype_to_serialized_name(config.activation_dtype)}", + ] + ) + compiler_extra_args = [ + f"--iree-hal-target-device={target_device}[{i}]" + for i in range(config.tensor_parallelism_size) + ] + + compile_file( + output_mlir_path, + input_type=InputType.TORCH, + output_file=output_module_path, + extra_args=compiler_extra_args, + ) -@pytest.mark.usefixtures("caching", "path_prefix") -class ShardedLlamaTest(unittest.TestCase): +def assert_close_cache_state( + actual: list[torch.Tensor], + expected: list[torch.Tensor], +): + torch.testing.assert_close( + actual[0].to(dtype=expected[0].dtype), expected[0], atol=1e-3, rtol=0 + ) + + +def assert_close_logits( + actual: torch.Tensor, + expected: torch.Tensor, +): + actual_probabilities = torch.softmax(actual, dim=1) + expected_probabilities = torch.softmax(expected, dim=1) + torch.testing.assert_close( + actual_probabilities.to(dtype=expected_probabilities.dtype), + expected_probabilities, + atol=1e-3, + rtol=0, + ) + + +def raise_multiple(errors): + if not errors: # list emptied, recursion ends + return + try: + raise errors.pop() # pop removes list entries + finally: + raise_multiple(errors) # recursion + + +def assert_close_post_call( + actual_logits: torch.Tensor, + expected_logits: torch.Tensor, + actual_cache_state: list[AnyTensor], + expected_cache_state: list[AnyTensor], +): + errors = [] + try: + assert_close_logits(actual_logits, expected_logits) + except Exception as ex: + errors.append(ex) + try: + assert_close_cache_state(actual_cache_state, expected_cache_state) + except Exception as ex: + errors.append(ex) + raise_multiple(errors) + + +def compare_models( + target_model: CausalLMModelABC, + reference_model: CausalLMModelABC, + tokenizer: InferenceTokenizer, + cache_page_count: int, + prompts: list[str], +): + generator = TorchGenerator( + target_model, tokenizer, page_cache_size=cache_page_count + ) + reference_generator = TorchGenerator( + reference_model, tokenizer, page_cache_size=cache_page_count + ) + batch = generator.begin_batch(prompts) + reference_batch = reference_generator.begin_batch(prompts) + + # Init the cache and copy it to both the target and the reference. 
+ unsharded_reference_cache_state = reference_model.cache.paged.unshard_state( + reference_batch.cache_state + ) + torch.full( + size=unsharded_reference_cache_state[0].shape, + fill_value=0, + out=unsharded_reference_cache_state[0], + ) + reference_batch.cache_state[0][...] = reference_model.cache.paged.shard_state( + unsharded_reference_cache_state + )[0] + batch.cache_state[0][...] = target_model.cache.paged.shard_state( + unsharded_reference_cache_state + )[0] + + batch.prefill() + reference_batch.prefill() + assert_close_post_call( + actual_logits=batch.logits, + expected_logits=reference_batch.logits, + actual_cache_state=target_model.cache.paged.unshard_state(batch.cache_state), + expected_cache_state=reference_batch.cache_state, + ) + + batch.decode() + reference_batch.decode() + assert_close_post_call( + actual_logits=batch.logits, + expected_logits=reference_batch.logits, + actual_cache_state=target_model.cache.paged.unshard_state(batch.cache_state), + expected_cache_state=reference_batch.cache_state, + ) + + +def run_test_compare_iree_against_torch( + path_prefix: str, + intermediates_caching: bool, + torch_dataset_path: str, + torch_config: LlamaModelConfig, + iree_dataset_path: str, + iree_config: LlamaModelConfig, + iree_target_device: str, + iree_driver: str, + tokenizer: InferenceTokenizer, + prompts: list[str], + cache_page_count: int, +): + iree_module_path = f"{path_prefix}program.vmfb" + compile_iree_module( + intermediates_caching=intermediates_caching, + config=iree_config, + dataset_path=iree_dataset_path, + batch_size=len(prompts), + target_device=iree_target_device, + output_mlir_path=f"{path_prefix}program.mlir", + output_module_path=iree_module_path, + output_config_path=f"{path_prefix}program_config.json", + ) + iree_devices = get_iree_devices( + driver=iree_driver, + device_count=iree_config.tensor_parallelism_size, + ) + iree_module, vm_context, vm_instance = load_iree_module( + module_path=iree_module_path, + devices=iree_devices, + parameters_path=iree_dataset_path, + ) + iree_model = CausalLMIreeModel( + batch_size=len(prompts), + config=iree_config, + vm_context=vm_context, + iree_driver=iree_driver, + iree_module=iree_module, + iree_devices=iree_devices, + ) + + torch_dataset = Dataset.load(torch_dataset_path, mmap=False) + torch_model = PagedLlamaModelV1(theta=torch_dataset.root_theta, config=torch_config) + + compare_models( + target_model=iree_model, + reference_model=torch_model, + tokenizer=tokenizer, + cache_page_count=cache_page_count, + prompts=prompts, + ) + + +@pytest.mark.usefixtures("caching") +class ShardedLlamaTestBase(PathPrefixTestBase): def setUp(self): + super().setUp() torch.random.manual_seed(123456) - self.dtype = torch.float32 - torch.set_default_dtype(self.dtype) + self.intermediates_caching = self.caching + self.prompts = [ + "The sky is blue", + "The night is dark", + "Linguistics is the study of", + ] + + +class ShardedLlamaToySizedTest(ShardedLlamaTestBase): + def setUp(self): + super().setUp() + self.reference_dtype = torch.float64 + self.target_dtype = torch.float32 + torch.set_default_dtype(self.reference_dtype) self.batch_size = 3 self.attention_head_count_kv = 4 self.attention_head_count = self.attention_head_count_kv * 5 @@ -47,10 +287,16 @@ def setUp(self): self.rope_dimension_count = 7 * 2 self.attn_head_dim = self.rope_dimension_count self.block_seq_stride = 13 + self.context_length = round_up_to_multiple_of( + functools.reduce(max, [len(prompt) for prompt in self.prompts]), + self.block_seq_stride, + ) + # Make this 
large enough to make torch.export.Dim happy. + self.context_length = max(self.context_length, 4 * self.block_seq_stride) self.cache_page_count = 11 self.config = LlamaModelConfig( hp=LlamaHParams( - context_length=self.block_seq_stride * 2, + context_length=self.context_length, embedding_length=self.attention_head_count * self.attn_head_dim, block_count=3, feed_forward_length=23, @@ -65,342 +311,157 @@ def setUp(self): model_arch="llama", ), block_seq_stride=self.block_seq_stride, - activation_dtype=self.dtype, - attention_dtype=self.dtype, + activation_dtype=self.reference_dtype, + attention_dtype=self.reference_dtype, + static_tables=False, ) self.sharded_config = deepcopy(self.config) self.sharded_config.tensor_parallelism_size = 2 + self.sharded_config.activation_dtype = self.target_dtype + self.sharded_config.attention_dtype = self.target_dtype + self.theta = make_random_llama_theta( config=self.config, vocab_size=self.vocabulary_size, ) - self.prefill_seq_lens = torch.tensor( - [14, 9, self.block_seq_stride - 1], dtype=torch.int64 - ) + self.theta.rename_tensors_to_paths() - def make_prefill_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]: - batch_seq_len = round_up_to_multiple_of( - int(torch.max(self.prefill_seq_lens)), model.cache.pad_sequence_stride - ) - token_ids = torch.randint( - low=0, - high=self.vocabulary_size, - size=[self.batch_size, batch_seq_len], - dtype=torch.int32, - ) - attention_mask = model.attention_mask( - model.input_mask(self.prefill_seq_lens, batch_seq_len) - ) - seq_block_ids = torch.arange( - self.batch_size * batch_seq_len // self.config.block_seq_stride - ).view(self.batch_size, -1) - cache_state = model.cache.paged.allocate(page_count=self.cache_page_count) - cache_state = [torch.rand_like(cache_state[0])] - return OrderedDict( - [ - ("tokens", token_ids), - ("attention_mask", attention_mask), - ("seq_block_ids", seq_block_ids), - ("cache_state", cache_state), - ] - ) + self.tokenizer = ModuloTokenizer(self.vocabulary_size) - def make_equal_unsharded_and_sharded_prefill_args( - self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1 - ) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]: - prefill_kwargs = self.make_prefill_args(model) - sharded_cache_state = sharded_model.cache.paged.allocate( - page_count=self.cache_page_count - ) - assert iterables_equal( - prefill_kwargs["cache_state"][0].shape, sharded_cache_state[0].shape - ) - sharded_prefill_kwargs = deepcopy(prefill_kwargs) - sharded_cache_state = sharded_model.cache.paged.shard_state( - sharded_prefill_kwargs["cache_state"] - ) - sharded_prefill_kwargs["cache_state"] = sharded_cache_state - - sharding = sharded_model.config.tensor_parallelism_size - for k in sharded_prefill_kwargs: - if k == "cache_state": - continue - sharded_prefill_kwargs[k] = ops.replicate( - sharded_prefill_kwargs[k], count=sharding - ) - - return prefill_kwargs, sharded_prefill_kwargs - - def make_decode_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]: - start_positions = self.prefill_seq_lens.clone() - seq_lens = self.prefill_seq_lens + 1 - batch_seq_len = round_up_to_multiple_of( - int(torch.max(seq_lens)), model.cache.pad_sequence_stride + def testCompareTensorParallelToUnsharded(self): + """Run a sharded variant of a toy model size and compare it against the + unsharded variant.""" + sharded_theta = self.theta.transform( + functools.partial(set_float_dtype, dtype=self.target_dtype) ) - decode_token_ids = torch.randint( - low=0, - high=self.vocabulary_size, - 
size=[self.batch_size, 1], - dtype=torch.int32, + sharded_theta = shard_theta(sharded_theta, self.sharded_config) + sharded_model = PagedLlamaModelV1(sharded_theta, self.sharded_config) + reference_model = PagedLlamaModelV1(self.theta, self.config) + compare_models( + target_model=sharded_model, + reference_model=reference_model, + tokenizer=self.tokenizer, + prompts=self.prompts, + cache_page_count=self.cache_page_count, ) - attention_mask = model.decode_attention_mask( - model.input_mask(seq_lens, batch_seq_len) + + def testCompareTensorParallelWithIreeToUnsharded(self): + """Test exporting to MLIR and compiling with IREE the sharded Llama model. + Test numerical accuracy of the IREE module against PyTorch.""" + + dataset = Dataset( + properties=self.config.hp.to_gguf_props(), root_theta=self.theta ) - seq_block_ids = torch.arange( - self.batch_size * batch_seq_len // self.config.block_seq_stride - ).view(self.batch_size, -1) - cache_state = model.cache.paged.allocate(page_count=self.cache_page_count) - cache_state = [torch.rand_like(cache_state[0])] - return OrderedDict( - [ - ("tokens", decode_token_ids), - ("attention_mask", attention_mask), - ("start_positions", start_positions), - ("seq_block_ids", seq_block_ids), - ("cache_state", cache_state), - ] + torch_dataset_path = f"{self.path_prefix}torch-reference-dataset.irpa" + if not self.intermediates_caching or not os.path.exists(torch_dataset_path): + dataset.save(torch_dataset_path) + + iree_unsharded_theta = self.theta.transform( + functools.partial(set_float_dtype, dtype=self.target_dtype) + ) + iree_unsharded_dataset = Dataset( + properties=self.sharded_config.hp.to_gguf_props(), + root_theta=iree_unsharded_theta, + ) + iree_usharded_dataset_path = f"{self.path_prefix}iree-dataset-unsharded.irpa" + if not self.intermediates_caching or not os.path.exists( + iree_usharded_dataset_path + ): + iree_unsharded_dataset.save(iree_usharded_dataset_path) + + iree_dataset_path = f"{self.path_prefix}iree-dataset.irpa" + + shard_dataset( + path=iree_usharded_dataset_path, + output_path=iree_dataset_path, + tensor_parallelism_size=self.sharded_config.tensor_parallelism_size, + intermediates_caching=self.intermediates_caching, ) - def make_equal_unsharded_and_sharded_decode_args( - self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1 - ) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]: - decode_kwargs = self.make_decode_args(model) - sharded_decode_kwargs = deepcopy(decode_kwargs) - sharded_decode_kwargs["cache_state"] = sharded_model.cache.paged.shard_state( - sharded_decode_kwargs["cache_state"] + run_test_compare_iree_against_torch( + path_prefix=self.path_prefix, + intermediates_caching=self.intermediates_caching, + torch_dataset_path=torch_dataset_path, + torch_config=self.config, + iree_dataset_path=iree_dataset_path, + iree_config=self.sharded_config, + iree_target_device="llvm-cpu", + iree_driver="local-task", + tokenizer=self.tokenizer, + prompts=self.prompts, + cache_page_count=self.cache_page_count, ) - sharding = sharded_model.config.tensor_parallelism_size - for k in sharded_decode_kwargs: - if k == "cache_state": - continue - sharded_decode_kwargs[k] = ops.replicate( - sharded_decode_kwargs[k], count=sharding - ) - return decode_kwargs, sharded_decode_kwargs +@pytest.mark.usefixtures("get_model_path") +class Llama38BFp16Tp8Test(ShardedLlamaTestBase): + def setUp(self): + super().setUp() + tokenizer_path = self.llama3_8b_tokenizer + self.tokenizer = load_tokenizer(tokenizer_path.parent) - def 
testCompareToySizedModelToUnsharded(self): - """Run a sharded variant of a toy model size and compare it against the - unsharded variant.""" - model = PagedLlamaModelV1(self.theta, self.config) - sharded_theta = shard_theta(self.theta, self.sharded_config) - sharded_model = PagedLlamaModelV1(sharded_theta, self.sharded_config) + self.reference_dtype = torch.float64 + self.dataset_path = str(self.llama3_8b_f16_model) + self.batch_size = 4 + self.cache_page_count = 8192 + tensor_parallelism_size = 8 - # Verify prefill step. - ( - prefill_kwargs, - sharded_prefill_kwargs, - ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model) - - expected_prefill_result = model.prefill(**prefill_kwargs) - sharded_prefill_result = sharded_model.prefill(**sharded_prefill_kwargs) - sharded_prefill_result = ops.unshard(sharded_prefill_result) - # The errors are quite high, but for float64 both errors drop to < 1e-12. - # The numerics are probably correct. - torch.testing.assert_close( - sharded_prefill_result, expected_prefill_result, atol=1e-3, rtol=1e-2 - ) - expected_cache_state = prefill_kwargs["cache_state"][0] - actual_cache_state = ops.unshard( - sharded_model.cache.paged.unflatten_page_table( - sharded_prefill_kwargs["cache_state"] - ) - ).flatten(start_dim=1) - torch.testing.assert_close( - actual_cache_state, expected_cache_state, atol=1e-4, rtol=1e-1 - ) + dataset = Dataset.load(self.dataset_path) + self.theta = dataset.root_theta - # Verify decode step. - ( - decode_kwargs, - sharded_decode_kwargs, - ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model) - expected_decode_result = model.decode(**decode_kwargs) - sharded_decode_result = sharded_model.decode(**sharded_decode_kwargs) - sharded_decode_result = ops.unshard(sharded_decode_result) - torch.testing.assert_close( - sharded_decode_result, expected_decode_result, atol=1e-4, rtol=1e-5 + self.config = LlamaModelConfig( + hp=LlamaHParams.from_gguf_props(dataset.properties), + activation_dtype=self.reference_dtype, + attention_dtype=self.reference_dtype, + static_tables=False, ) - expected_decode_cache_state = decode_kwargs["cache_state"][0] - actual_decode_cache_state = ops.unshard( - sharded_model.cache.paged.unflatten_page_table( - sharded_decode_kwargs["cache_state"] - ) - ).flatten(start_dim=1) - # TODO: investigate why the Windows machine CI is producing a larger numerical - # error. - # The Ubuntu CI runs fine with default tolerances. - torch.testing.assert_close( - actual_decode_cache_state, expected_decode_cache_state, atol=1e-4, rtol=1e-4 + self.sharded_config = LlamaModelConfig( + hp=LlamaHParams.from_gguf_props(dataset.properties), + tensor_parallelism_size=tensor_parallelism_size, + static_tables=False, # Rely on the compiler for hoisting tables. ) - @skip( - ( - "Before this does not crash at all we need " - "https://github.com/iree-org/iree/pull/18663 merged." - ) + def tearDown(self): + # make sure we don't reference the memory mapped file. + del self.theta + super().tearDown() + + @longrun + @pytest.mark.xfail( + reason="Numerics are not close.", raises=AssertionError, strict=True ) - def testExportAndRunToySizedModelWithIree(self): + def testCompareTensorParallelWithIreeToUnsharded(self): """Test exporting to MLIR and compiling with IREE the sharded Llama model. 
Test numerical accuracy of the IREE module against PyTorch.""" - if self.path_prefix is not None: - self.runTestExportAndRunToySizedModelWithIree( - path_prefix=self.path_prefix, dump_enabled=True - ) - else: - with tempfile.TemporaryDirectory() as temp_dir: - self.runTestExportAndRunToySizedModelWithIree( - path_prefix=f"{temp_dir}/", dump_enabled=False - ) - - def runTestExportAndRunToySizedModelWithIree( - self, path_prefix: str, dump_enabled: bool - ): - sharded_theta = shard_theta(self.theta, self.sharded_config) - sharded_theta.rename_tensors_to_paths() - sharded_dataset = Dataset({}, sharded_theta) - sharded_parameters_path = f"{path_prefix}parameters.irpa" - sharded_dataset.save(sharded_parameters_path) - sharded_dataset = Dataset.load(sharded_parameters_path, mmap=False) - iree_driver = "local-task" - - model = PagedLlamaModelV1(self.theta, self.config) - sharded_model = PagedLlamaModelV1( - sharded_dataset.root_theta, self.sharded_config - ) - ( - _, - sharded_prefill_kwargs, - ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model) - ( - _, - sharded_decode_kwargs, - ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model) - - iree_module_path = f"{path_prefix}program.vmfb" - if not self.caching or not os.path.exists(iree_module_path): - # Export and compile the IREE module. - sharded_fxb = FxProgramsBuilder(sharded_model) - - @sharktank_export( - fx_builder=sharded_fxb, - name="prefill", - kwargs=sharded_prefill_kwargs, - strict=False, - ) - def _(model, *args, **kwargs) -> torch.Tensor: - return model.prefill(*args, **kwargs) - - # TODO: remove strict=False when - # https://github.com/pytorch/pytorch/issues/136757 - # is resolved. - @sharktank_export( - fx_builder=sharded_fxb, - name="decode", - kwargs=sharded_decode_kwargs, - strict=False, - ) - def _(model, *args, **kwargs) -> torch.Tensor: - return model.decode(*args, **kwargs) - - output = export(sharded_fxb) - if dump_enabled: - output.save_mlir(f"{path_prefix}program.mlir") - output.session.set_flags( - *[ - f"--iree-hal-target-device=llvm-cpu[{i}]" - for i in range(self.sharded_config.tensor_parallelism_size) - ] - ) - output.compile( - save_to=iree_module_path, - target_backends=None, - ) - - iree_devices = get_iree_devices( - driver=iree_driver, - device_count=self.sharded_config.tensor_parallelism_size, - ) - iree_module, vm_context, vm_instance = load_iree_module( - module_path=iree_module_path, - devices=iree_devices, - parameters_path=sharded_parameters_path, - ) - - # Run prefill step. 
- prefill_iree_args = prepare_iree_module_function_args( - args=deepcopy(sharded_prefill_kwargs).values(), devices=iree_devices + reference_theta = self.theta.transform( + functools.partial(set_float_dtype, dtype=self.reference_dtype) ) - for i, arg in enumerate(prefill_iree_args): - np.save(f"{path_prefix}prefill_arg{i}.npy", arg.to_host()) - prefill_iree_result = run_iree_module_function( - args=prefill_iree_args, - function_name="prefill", - module=iree_module, - vm_context=vm_context, - driver=iree_driver, - trace_path_prefix=path_prefix if dump_enabled else None, - ) - prefill_iree_result = UnreducedTensor(ts=iree_to_torch(*prefill_iree_result)) - expected_prefill_result = call_torch_module_function( - module=sharded_model, - function_name="prefill", - kwargs=sharded_prefill_kwargs, - trace_path_prefix=f"{path_prefix}expected_" if dump_enabled else None, - ) - prefill_iree_cache_state_shards = prefill_iree_args[ - -self.config.tensor_parallelism_size - 1 : - ] - prefill_iree_cache_state = SplitPrimitiveTensor( - ts=iree_to_torch(*prefill_iree_cache_state_shards), - shard_dim=sharded_prefill_kwargs["cache_state"][0].shard_dim, + reference_dataset = Dataset( + properties=self.config.hp.to_gguf_props(), root_theta=reference_theta ) + reference_dataset_path = f"{self.path_prefix}torch-reference-dataset.irpa" + if not self.intermediates_caching or not os.path.exists(reference_dataset_path): + reference_dataset.save(reference_dataset_path) + target_dataset_path = f"{self.path_prefix}iree-dataset.irpa" - # Run decode step. - decode_iree_args = prepare_iree_module_function_args( - args=deepcopy(sharded_decode_kwargs).values(), devices=iree_devices - ) - decode_iree_result = run_iree_module_function( - args=decode_iree_args, - function_name="decode", - module=iree_module, - vm_context=vm_context, - driver=iree_driver, - trace_path_prefix=path_prefix if dump_enabled else None, - ) - decode_iree_result = UnreducedTensor(ts=iree_to_torch(*decode_iree_result)) - expected_decode_result = call_torch_module_function( - module=sharded_model, - function_name="decode", - kwargs=sharded_decode_kwargs, - trace_path_prefix=f"{path_prefix}expected_" if dump_enabled else None, - ) - decode_iree_cache_state_shards = decode_iree_args[ - -self.config.tensor_parallelism_size - 1 : - ] - decode_iree_cache_state = SplitPrimitiveTensor( - ts=iree_to_torch(*decode_iree_cache_state_shards), - shard_dim=sharded_decode_kwargs["cache_state"][0].shard_dim, + shard_dataset( + path=self.dataset_path, + output_path=target_dataset_path, + tensor_parallelism_size=self.sharded_config.tensor_parallelism_size, + intermediates_caching=self.intermediates_caching, ) - # Check IREE's numerical correctness against PyTorch. - # TODO: Although, not entirely wrong, investigate why this accuracy is that - # low for fp32 (atol=0.0011, rtol=0.013). 
- torch.testing.assert_close( - ops.unshard(prefill_iree_result), - ops.unshard(expected_prefill_result), - ) - torch.testing.assert_close( - ops.unshard(prefill_iree_cache_state), - ops.unshard(sharded_prefill_kwargs["cache_state"][0]), - ) - torch.testing.assert_close( - ops.unshard(decode_iree_result), - ops.unshard(expected_decode_result), - ) - torch.testing.assert_close( - ops.unshard(decode_iree_cache_state), - ops.unshard(sharded_decode_kwargs["cache_state"][0]), + run_test_compare_iree_against_torch( + path_prefix=self.path_prefix, + intermediates_caching=self.intermediates_caching, + torch_dataset_path=self.dataset_path, + torch_config=self.config, + iree_dataset_path=target_dataset_path, + iree_config=self.sharded_config, + iree_target_device="llvm-cpu", + iree_driver="local-task", + tokenizer=self.tokenizer, + prompts=self.prompts, + cache_page_count=self.cache_page_count, )
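
# For reference, a minimal standalone sketch of how the ModuloTokenizer added in
# sharktank/sharktank/utils/testing.py maps text to token IDs: each character becomes
# ord(c) % vocabulary_size, so every produced ID stays below the vocabulary size, and
# decoding only round-trips prompts whose code points were already below that bound.
# The class name ModuloTokenizerSketch and the vocabulary_size of 256 are illustrative
# assumptions for this sketch, not part of the patch.

class ModuloTokenizerSketch:
    def __init__(self, vocabulary_size: int):
        self.vocabulary_size = vocabulary_size

    def encode(self, texts: list[str]) -> list[list[int]]:
        # Each character is reduced modulo the vocabulary size, guaranteeing
        # every token ID is in [0, vocabulary_size).
        return [[ord(c) % self.vocabulary_size for c in text] for text in texts]

    def decode(self, tokens: list[list[int]]) -> list[str]:
        # Lossless only when the original code points were already below
        # vocabulary_size (e.g. plain ASCII with vocabulary_size >= 128).
        return ["".join(chr(t) for t in prompt) for prompt in tokens]


if __name__ == "__main__":
    tokenizer = ModuloTokenizerSketch(vocabulary_size=256)
    prompts = ["The sky is blue", "The night is dark"]
    ids = tokenizer.encode(prompts)
    # Round-trips here because all characters are ASCII and 256 > 127.
    assert tokenizer.decode(ids) == prompts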