diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 5fb905dbf..ab1fd9dbb 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -10,6 +10,7 @@ from typing import Optional, Tuple
 from datetime import datetime
 
 import torch
+from tp import maybe_init_dist, apply_tp
 import torchao
 import torch._dynamo.config
 import torch._inductor.config
@@ -139,15 +140,21 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):
         tokens = [tokenizer.bos_id()] + tokens
     return torch.tensor(tokens, dtype=torch.int, device=device)
 
-def _load_model(checkpoint_path, device, precision):
+def _load_model(checkpoint_path, device, precision, use_tp):
+    use_cuda = 'cuda' in device
+    with torch.device('meta'):
+        model = Transformer.from_name(checkpoint_path.parent.name)
     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    print("Successfully loaded checkpoint")
     if "model" in checkpoint and "stories" in str(checkpoint_path):
         checkpoint = checkpoint["model"]
-
-    model = Transformer.from_name(checkpoint_path.parent.name)
     model.load_state_dict(checkpoint, assign=True)
-    model = model.to(device=device, dtype=precision)
+    if use_tp:
+        print("Applying tensor parallel to model ...")
+        apply_tp(model)
+
+    model = model.to(device=device, dtype=precision)
 
     return model.eval()
 
 B_INST, E_INST = "[INST]", "[/INST]"
@@ -182,13 +189,20 @@ def main(
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
     assert tokenizer_path.is_file(), str(tokenizer_path)
 
+    global print
+    rank = maybe_init_dist() # Highlight: Initialize distributed training
+    use_tp = rank is not None
+    if use_tp:
+        if rank != 0:
+            # only print on rank 0
+            print = lambda *args, **kwargs: None
+
     print(f"Using device={device}")
     is_chat = "chat" in str(checkpoint_path)
 
     print("Loading model ...")
     t0 = time.time()
-    model = _load_model(checkpoint_path, device, precision)
-
+    model = _load_model(checkpoint_path, device, precision, use_tp=use_tp)
     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
@@ -315,7 +329,7 @@ def callback(x):
             callback = lambda x : x
         t0 = time.perf_counter()
         import contextlib
-        if (i != num_samples - 1 or not profile):
+        if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
            prof = contextlib.nullcontext()
         else:
             torch.profiler._utils._init_for_cuda_graphs()
@@ -337,7 +351,10 @@ def callback(x):
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
         if hasattr(prof, "export_chrome_trace"):
-            prof.export_chrome_trace(f"{profile}.json")
+            if use_tp:
+                prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
+            else:
+                prof.export_chrome_trace(f"{profile}.json")
         device_sync(device=device) # MKG
         t = time.perf_counter() - t0
 
@@ -434,4 +451,4 @@ def callback(x):
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
-    )
+    )
\ No newline at end of file
diff --git a/torchao/_models/llama/tp.py b/torchao/_models/llama/tp.py
new file mode 100644
index 000000000..4f3369c13
--- /dev/null
+++ b/torchao/_models/llama/tp.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# copied from https://github.com/pytorch-labs/gpt-fast/blob/main/tp.py
+import os
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+from torch import nn
+if os.uname().sysname != "Darwin":
+    from torch.distributed import _functional_collectives as funcol
+else:
+    # Distributed is not supported on MacOS
+    funcol = None
+
+from torchao._models.llama.model import Attention, FeedForward, Transformer
+from torchao.quantization.GPTQ import WeightOnlyInt4Linear
+
+
+def _get_rank() -> int:
+    return int(os.environ.get("LOCAL_RANK", "0"))
+
+def is_local():
+    return _get_rank() == 0
+
+def local_break():
+    if is_local():
+        breakpoint()
+    dist.barrier()
+
+def _get_world_size() -> int:
+    return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
+
+def maybe_init_dist() -> Optional[int]:
+    try:
+        # provided by torchrun
+        rank = _get_rank()
+        world_size = _get_world_size()
+
+        if world_size < 2:
+            # too few gpus to parallelize, tp is no-op
+            return None
+    except KeyError:
+        # not run via torchrun, no-op
+        return None
+
+    torch.cuda.set_device(rank)
+    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    return rank
+
+
+def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None:
+    rank = _get_rank()
+    world_size = _get_world_size()
+
+    # Linear's weight matrix is transposed, and is of shape
+    # (linear.out_features, linear.in_features)
+    dim_lookup = {
+        "colwise": (0, "out_features"),
+        "rowwise": (1, "in_features")
+    }
+    assert style in dim_lookup
+    shard_dim, size_attr = dim_lookup[style]
+
+    # ensure we can shard evenly
+    assert getattr(linear, size_attr) % world_size == 0
+    def shard(x, dim):
+        assert x.size(dim=dim) % world_size == 0
+        return torch.tensor_split(x, world_size, dim=dim)[rank]
+
+    def shard_qkv(qkv, dim, weight_splits):
+        q, k, v = qkv.split(weight_splits, dim=dim)
+        q = shard(q, dim)
+        k = shard(k, dim)
+        v = shard(v, dim)
+        return torch.cat((q,k,v), dim=dim)
+
+    # shard
+    if weight_splits:
+        # attention
+        assert len(weight_splits) == 3
+
+        if isinstance(linear, WeightOnlyInt4Linear):
+            sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
+            linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits)
+        else:
+            sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
+        if hasattr(linear, "scales") and style == "colwise":
+            linear.scales = shard_qkv(linear.scales, 0, weight_splits)
+    else:
+        sharded_weight = shard(linear.weight, shard_dim)
+        if isinstance(linear, WeightOnlyInt4Linear):
+            linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
+            if style == "rowwise":
+                assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3]
+                assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8
+        if hasattr(linear, "scales") and style == "colwise":
+            linear.scales = shard(linear.scales, 0)
+
+    # local_break()
+    linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
+    setattr(linear, size_attr, getattr(linear, size_attr) // world_size)
+
+    # shape info should still be synced
+    # assert linear.weight.shape == (linear.out_features, linear.in_features)
+
+
+def _apply_tp_ffn(mlp: FeedForward) -> None:
+    assert hasattr(mlp, "w1")
+    assert hasattr(mlp, "w3")
+    assert hasattr(mlp, "w2")
+
+    _apply_tp_linear(mlp.w1, "colwise")
+    _apply_tp_linear(mlp.w3, "colwise")
+    _apply_tp_linear(mlp.w2, "rowwise")
+
+    world_size = _get_world_size()
+    def all_reduce_hook(module, input, output):
+        dist.all_reduce(output)
+        return output
+    mlp.register_forward_hook(all_reduce_hook)
+
+
+def _apply_tp_attn(attn: Attention) -> None:
+    assert hasattr(attn, "wqkv")
+    assert hasattr(attn, "wo")
+
+    kv_size = attn.n_local_heads * attn.head_dim
+    _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size])
+    _apply_tp_linear(attn.wo, "rowwise")
+
+    # overwrite
+    world_size = _get_world_size()
+    attn.n_head = attn.n_head // world_size
+    attn.dim = attn.dim // world_size
+    attn.head_dim = attn.dim // attn.n_head
+    attn.n_local_heads = attn.n_local_heads // world_size
+    def all_reduce_hook(module, input, output):
+        dist.all_reduce(output[0])
+        return output
+
+    attn.register_forward_hook(all_reduce_hook)
+
+
+def _apply_tp_Transformer(Transformer: Transformer) -> None:
+    # overwrite config before Transformer.setup_cache is called
+    world_size = _get_world_size()
+    Transformer.config.n_head = Transformer.config.n_head // world_size
+    Transformer.config.dim = Transformer.config.dim // world_size
+    Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size
+
+
+def apply_tp(model: Transformer) -> None:
+    _apply_tp_Transformer(model)
+    for block in model.layers:
+        # Apply to MLP
+        _apply_tp_ffn(block.feed_forward)
+        _apply_tp_attn(block.attention)
\ No newline at end of file