-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
cleaned up and tested tp support #976
Open
debajyotidatta
wants to merge
2
commits into
pytorch:main
Choose a base branch
from
debajyotidatta:dd-add-tp-support
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
|
||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# copied from https://github.com/pytorch-labs/gpt-fast/blob/main/tp.py | ||
import os | ||
from typing import List, Optional | ||
|
||
import torch | ||
import torch.distributed as dist | ||
from torch import nn | ||
if os.uname().sysname != "Darwin": | ||
from torch.distributed import _functional_collectives as funcol | ||
else: | ||
# Distributed is not supported on MacOS | ||
funcol = None | ||
|
||
from torchao._models.llama.model import Attention, FeedForward, Transformer | ||
from torchao.quantization.GPTQ import WeightOnlyInt4Linear | ||
|
||
|
||
def _get_rank() -> int: | ||
return int(os.environ.get("LOCAL_RANK", "0")) | ||
|
||
def is_local(): | ||
return _get_rank() == 0 | ||
|
||
def local_break(): | ||
if is_local(): | ||
breakpoint() | ||
dist.barrier() | ||
|
||
def _get_world_size() -> int: | ||
return int(os.environ.get("LOCAL_WORLD_SIZE", "1")) | ||
|
||
def maybe_init_dist() -> Optional[int]: | ||
try: | ||
# provided by torchrun | ||
rank = _get_rank() | ||
world_size = _get_world_size() | ||
|
||
if world_size < 2: | ||
# too few gpus to parallelize, tp is no-op | ||
return None | ||
except KeyError: | ||
# not run via torchrun, no-op | ||
return None | ||
|
||
torch.cuda.set_device(rank) | ||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) | ||
return rank | ||
|
||
|
||
def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None: | ||
rank = _get_rank() | ||
world_size = _get_world_size() | ||
|
||
# Linear's weight matrix is transposed, and is of shape | ||
# (linear.out_features, linear.in_features) | ||
dim_lookup = { | ||
"colwise": (0, "out_features"), | ||
"rowwise": (1, "in_features") | ||
} | ||
assert style in dim_lookup | ||
shard_dim, size_attr = dim_lookup[style] | ||
|
||
# ensure we can shard evenly | ||
assert getattr(linear, size_attr) % world_size == 0 | ||
def shard(x, dim): | ||
assert x.size(dim=dim) % world_size == 0 | ||
return torch.tensor_split(x, world_size, dim=dim)[rank] | ||
|
||
def shard_qkv(qkv, dim, weight_splits): | ||
q, k, v = qkv.split(weight_splits, dim=dim) | ||
q = shard(q, dim) | ||
k = shard(k, dim) | ||
v = shard(v, dim) | ||
return torch.cat((q,k,v), dim=dim) | ||
|
||
# shard | ||
if weight_splits: | ||
# attention | ||
assert len(weight_splits) == 3 | ||
|
||
if isinstance(linear, WeightOnlyInt4Linear): | ||
sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits]) | ||
linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits) | ||
else: | ||
sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits) | ||
if hasattr(linear, "scales") and style == "colwise": | ||
linear.scales = shard_qkv(linear.scales, 0, weight_splits) | ||
else: | ||
sharded_weight = shard(linear.weight, shard_dim) | ||
if isinstance(linear, WeightOnlyInt4Linear): | ||
linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim) | ||
if style == "rowwise": | ||
assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3] | ||
assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8 | ||
if hasattr(linear, "scales") and style == "colwise": | ||
linear.scales = shard(linear.scales, 0) | ||
|
||
# local_break() | ||
linear.weight = nn.Parameter(sharded_weight, requires_grad=False) | ||
setattr(linear, size_attr, getattr(linear, size_attr) // world_size) | ||
|
||
# shape info should still be synced | ||
# assert linear.weight.shape == (linear.out_features, linear.in_features) | ||
|
||
|
||
def _apply_tp_ffn(mlp: FeedForward) -> None: | ||
assert hasattr(mlp, "w1") | ||
assert hasattr(mlp, "w3") | ||
assert hasattr(mlp, "w2") | ||
|
||
_apply_tp_linear(mlp.w1, "colwise") | ||
_apply_tp_linear(mlp.w3, "colwise") | ||
_apply_tp_linear(mlp.w2, "rowwise") | ||
|
||
world_size = _get_world_size() | ||
def all_reduce_hook(module, input, output): | ||
dist.all_reduce(output) | ||
return output | ||
mlp.register_forward_hook(all_reduce_hook) | ||
|
||
|
||
def _apply_tp_attn(attn: Attention) -> None: | ||
assert hasattr(attn, "wqkv") | ||
assert hasattr(attn, "wo") | ||
|
||
kv_size = attn.n_local_heads * attn.head_dim | ||
_apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size]) | ||
_apply_tp_linear(attn.wo, "rowwise") | ||
|
||
# overwrite | ||
world_size = _get_world_size() | ||
attn.n_head = attn.n_head // world_size | ||
attn.dim = attn.dim // world_size | ||
attn.head_dim = attn.dim // attn.n_head | ||
attn.n_local_heads = attn.n_local_heads // world_size | ||
def all_reduce_hook(module, input, output): | ||
dist.all_reduce(output[0]) | ||
return output | ||
|
||
attn.register_forward_hook(all_reduce_hook) | ||
|
||
|
||
def _apply_tp_Transformer(Transformer: Transformer) -> None: | ||
# overwrite config before Transformer.setup_cache is called | ||
world_size = _get_world_size() | ||
Transformer.config.n_head = Transformer.config.n_head // world_size | ||
Transformer.config.dim = Transformer.config.dim // world_size | ||
Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size | ||
|
||
|
||
def apply_tp(model: Transformer) -> None: | ||
_apply_tp_Transformer(model) | ||
for block in model.layers: | ||
# Apply to MLP | ||
_apply_tp_ffn(block.feed_forward) | ||
_apply_tp_attn(block.attention) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we are not using
WeightOnlyInt4Linear
any more in torchao I think, is this just for GPTQ?