cleaned up and tested tp support #976

Open: wants to merge 2 commits into base: main
35 changes: 26 additions & 9 deletions torchao/_models/llama/generate.py
@@ -10,6 +10,7 @@
from typing import Optional, Tuple
from datetime import datetime
import torch
from tp import maybe_init_dist, apply_tp
import torchao
import torch._dynamo.config
import torch._inductor.config
@@ -139,15 +140,21 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)

def _load_model(checkpoint_path, device, precision):
def _load_model(checkpoint_path, device, precision, use_tp):
    use_cuda = 'cuda' in device
    with torch.device('meta'):
        model = Transformer.from_name(checkpoint_path.parent.name)
    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    print("Successfully loaded checkpoint")
    if "model" in checkpoint and "stories" in str(checkpoint_path):
        checkpoint = checkpoint["model"]

    model = Transformer.from_name(checkpoint_path.parent.name)
    model.load_state_dict(checkpoint, assign=True)
    model = model.to(device=device, dtype=precision)

    if use_tp:
        print("Applying tensor parallel to model ...")
        apply_tp(model)

    model = model.to(device=device, dtype=precision)
    return model.eval()

B_INST, E_INST = "[INST]", "[/INST]"
@@ -182,13 +189,20 @@ def main(
    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
    assert tokenizer_path.is_file(), str(tokenizer_path)

    global print
    rank = maybe_init_dist()  # Highlight: initialize distributed (tensor parallel) if launched via torchrun
    use_tp = rank is not None
    if use_tp:
        if rank != 0:
            # only print on rank 0
            print = lambda *args, **kwargs: None

    print(f"Using device={device}")
    is_chat = "chat" in str(checkpoint_path)

    print("Loading model ...")
    t0 = time.time()
    model = _load_model(checkpoint_path, device, precision)

    model = _load_model(checkpoint_path, device, precision, use_tp=use_tp)

    device_sync(device=device) # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")
@@ -315,7 +329,7 @@ def callback(x):
            callback = lambda x : x
        t0 = time.perf_counter()
        import contextlib
        if (i != num_samples - 1 or not profile):
        if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
            prof = contextlib.nullcontext()
        else:
            torch.profiler._utils._init_for_cuda_graphs()
@@ -337,7 +351,10 @@ def callback(x):
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue
        if hasattr(prof, "export_chrome_trace"):
            prof.export_chrome_trace(f"{profile}.json")
            if use_tp:
                prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
            else:
                prof.export_chrome_trace(f"{profile}.json")
        device_sync(device=device) # MKG
        t = time.perf_counter() - t0

@@ -434,4 +451,4 @@ def callback(x):
    main(
        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
    )
    )
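Note on exercising the new path (not part of the diff): maybe_init_dist() keys off the LOCAL_RANK / LOCAL_WORLD_SIZE variables that torchrun exports for each worker, so tensor parallelism is activated by launching the script with something like torchrun --standalone --nproc_per_node 2 generate.py ... while a plain python generate.py run keeps the existing single-device behaviour. A minimal sketch of that fallback, assuming it is run from torchao/_models/llama/:

import os

# simulate a plain `python generate.py` launch: torchrun has not set these
os.environ.pop("LOCAL_RANK", None)
os.environ.pop("LOCAL_WORLD_SIZE", None)

from tp import maybe_init_dist

rank = maybe_init_dist()   # world size defaults to 1, so this is a no-op
assert rank is None        # use_tp stays False; _load_model(..., use_tp=False)
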
162 changes: 162 additions & 0 deletions torchao/_models/llama/tp.py
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# copied from https://github.com/pytorch-labs/gpt-fast/blob/main/tp.py
import os
from typing import List, Optional

import torch
import torch.distributed as dist
from torch import nn
if os.uname().sysname != "Darwin":
    from torch.distributed import _functional_collectives as funcol
else:
    # Distributed is not supported on MacOS
    funcol = None

from torchao._models.llama.model import Attention, FeedForward, Transformer
from torchao.quantization.GPTQ import WeightOnlyInt4Linear
Contributor (review comment on this import): we are not using WeightOnlyInt4Linear any more in torchao I think, is this just for GPTQ?

def _get_rank() -> int:
    return int(os.environ.get("LOCAL_RANK", "0"))

def is_local():
    return _get_rank() == 0

def local_break():
    if is_local():
        breakpoint()
    dist.barrier()

def _get_world_size() -> int:
    return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

def maybe_init_dist() -> Optional[int]:
    try:
        # provided by torchrun
        rank = _get_rank()
        world_size = _get_world_size()

        if world_size < 2:
            # too few gpus to parallelize, tp is no-op
            return None
    except KeyError:
        # not run via torchrun, no-op
        return None

    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    return rank


def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None:
    rank = _get_rank()
    world_size = _get_world_size()

    # Linear's weight matrix is transposed, and is of shape
    # (linear.out_features, linear.in_features)
    dim_lookup = {
        "colwise": (0, "out_features"),
        "rowwise": (1, "in_features")
    }
    assert style in dim_lookup
    shard_dim, size_attr = dim_lookup[style]

    # ensure we can shard evenly
    assert getattr(linear, size_attr) % world_size == 0
    def shard(x, dim):
        assert x.size(dim=dim) % world_size == 0
        return torch.tensor_split(x, world_size, dim=dim)[rank]

    def shard_qkv(qkv, dim, weight_splits):
        q, k, v = qkv.split(weight_splits, dim=dim)
        q = shard(q, dim)
        k = shard(k, dim)
        v = shard(v, dim)
        return torch.cat((q,k,v), dim=dim)

    # shard
    if weight_splits:
        # attention
        assert len(weight_splits) == 3

        if isinstance(linear, WeightOnlyInt4Linear):
            sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
            linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits)
        else:
            sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
        if hasattr(linear, "scales") and style == "colwise":
            linear.scales = shard_qkv(linear.scales, 0, weight_splits)
    else:
        sharded_weight = shard(linear.weight, shard_dim)
        if isinstance(linear, WeightOnlyInt4Linear):
            linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
            if style == "rowwise":
                assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3]
                assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8
        if hasattr(linear, "scales") and style == "colwise":
            linear.scales = shard(linear.scales, 0)

    # local_break()
    linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
    setattr(linear, size_attr, getattr(linear, size_attr) // world_size)

    # shape info should still be synced
    # assert linear.weight.shape == (linear.out_features, linear.in_features)
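For orientation, an illustrative single-process sketch (sizes and the simulated rank are made up, not from the PR) of what the two styles do to a plain nn.Linear weight; the setattr above then records the shrunken dimension:

import torch
from torch import nn

world_size, rank = 2, 0              # pretend we are rank 0 of 2
lin = nn.Linear(16, 32, bias=False)  # weight shape: (out_features, in_features) = (32, 16)

colwise = torch.tensor_split(lin.weight, world_size, dim=0)[rank]  # keeps 32/2 output rows
rowwise = torch.tensor_split(lin.weight, world_size, dim=1)[rank]  # keeps 16/2 input columns
assert colwise.shape == (16, 16) and rowwise.shape == (32, 8)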


def _apply_tp_ffn(mlp: FeedForward) -> None:
    assert hasattr(mlp, "w1")
    assert hasattr(mlp, "w3")
    assert hasattr(mlp, "w2")

    _apply_tp_linear(mlp.w1, "colwise")
    _apply_tp_linear(mlp.w3, "colwise")
    _apply_tp_linear(mlp.w2, "rowwise")

    world_size = _get_world_size()
    def all_reduce_hook(module, input, output):
        dist.all_reduce(output)
        return output
    mlp.register_forward_hook(all_reduce_hook)
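Why a single all-reduce after the row-wise w2 is enough: with w1/w3 sharded column-wise and w2 row-wise, every rank produces a partial FFN output whose elementwise sum equals the unsharded result, which is what the forward hook above contributes. A single-process sketch with made-up sizes, assuming the usual LLaMA FFN w2(silu(w1(x)) * w3(x)) and using an explicit sum where the hook would call dist.all_reduce:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size, hidden, intermediate = 2, 8, 16
x = torch.randn(4, hidden)
w1 = torch.randn(intermediate, hidden)   # colwise: split along dim 0
w3 = torch.randn(intermediate, hidden)   # colwise: split along dim 0
w2 = torch.randn(hidden, intermediate)   # rowwise: split along dim 1

full = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

partials = []
for rank in range(world_size):
    w1_s = torch.tensor_split(w1, world_size, dim=0)[rank]
    w3_s = torch.tensor_split(w3, world_size, dim=0)[rank]
    w2_s = torch.tensor_split(w2, world_size, dim=1)[rank]
    partials.append(F.linear(F.silu(F.linear(x, w1_s)) * F.linear(x, w3_s), w2_s))

# the sum stands in for dist.all_reduce(output) in the hook
assert torch.allclose(full, sum(partials), atol=1e-5)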


def _apply_tp_attn(attn: Attention) -> None:
    assert hasattr(attn, "wqkv")
    assert hasattr(attn, "wo")

    kv_size = attn.n_local_heads * attn.head_dim
    _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size])
    _apply_tp_linear(attn.wo, "rowwise")

    # overwrite
    world_size = _get_world_size()
    attn.n_head = attn.n_head // world_size
    attn.dim = attn.dim // world_size
    attn.head_dim = attn.dim // attn.n_head
    attn.n_local_heads = attn.n_local_heads // world_size
    def all_reduce_hook(module, input, output):
        dist.all_reduce(output[0])
        return output

    attn.register_forward_hook(all_reduce_hook)
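The fused wqkv weight is not tensor_split directly because a plain split would hand each rank a slice that crosses the q/k/v boundaries; shard_qkv therefore splits the fused weight into its q, k and v blocks first, shards each block across ranks, and re-concatenates. An illustrative shape check (sizes made up, rank simulated in-process):

import torch

world_size, rank = 2, 0
dim, head_dim, n_local_heads = 32, 4, 4     # n_local_heads = number of kv heads
kv_size = n_local_heads * head_dim          # 16
wqkv = torch.randn(dim + 2 * kv_size, dim)  # fused out_features = 64

q, k, v = wqkv.split([dim, kv_size, kv_size], dim=0)
shard = lambda t: torch.tensor_split(t, world_size, dim=0)[rank]
sharded = torch.cat((shard(q), shard(k), shard(v)), dim=0)

# each rank keeps dim/2 query rows plus kv_size/2 rows for each of k and v
assert sharded.shape == ((dim + 2 * kv_size) // world_size, dim)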


def _apply_tp_Transformer(Transformer: Transformer) -> None:
    # overwrite config before Transformer.setup_cache is called
    world_size = _get_world_size()
    Transformer.config.n_head = Transformer.config.n_head // world_size
    Transformer.config.dim = Transformer.config.dim // world_size
    Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size


def apply_tp(model: Transformer) -> None:
    _apply_tp_Transformer(model)
    for block in model.layers:
        # Apply to MLP
        _apply_tp_ffn(block.feed_forward)
        _apply_tp_attn(block.attention)