fixing GPTQ
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 7a9ae0d1d03794c5e5d85e3674fc88d2813eaf23
Pull Request resolved: #147
HDCharles committed Mar 26, 2024
1 parent 93dab0e commit 94a75dc
Showing 2 changed files with 142 additions and 1 deletion.
123 changes: 122 additions & 1 deletion GPTQ.py
@@ -129,6 +129,127 @@ def cuda(self):
        self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]


class GPTQMultiTensor(torch.Tensor):
    """
    Tensor subclass that holds a list of calibration inputs and intercepts
    nn.functional.linear through __torch_function__ so GPTQ can accumulate a
    Hessian estimate for each linear layer's weight.
    """
    # todo need default shape/dtype
    @staticmethod
    def __new__(cls, input, **kwargs):
        kwargs["dtype"] = kwargs.get("dtype", input.dtype)
        shape = kwargs.pop("shape", input.shape)
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, input, **kwargs):
        self.values = []
        self.append(input)
        self.debug = False

    def append(self, input):
        if isinstance(input, (tuple, list)):
            for inp in input:
                self.values.append(inp)
        elif isinstance(input, torch.Tensor):
            self.values.append(input)

    # def __add__(self, other):
    #     for val in other.values:
    #         self.append(val)

    def count(self):
        return len(self.values)

    def cuda(self):
        self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None, skip_quant=False):
        def tensors_to_cuda(args):
            new_args = []
            for x in args:
                new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
            return new_args

        kwargs = {} if kwargs is None else kwargs
        # combine args and kwargs
        flat_args, spec = tree_flatten((args, kwargs))
        # move single tensors to cuda
        flat_args = tensors_to_cuda(flat_args)
        # size of biggest MultiTensor
        multi_tensor_size = max(
            [x.count() if isinstance(x, GPTQMultiTensor) else 1 for x in flat_args]
        )
        # convert [a, MultiTensor(b,b,b), MultiTensor(c,c,c)] => [a,b,c], [a,b,c], [a,b,c]
        grouped_args = list(
            zip(
                *[x.values if isinstance(x, GPTQMultiTensor) else [x] * multi_tensor_size for x in flat_args]
            )
        )

        quantize_linear = (
            func is nn.functional.linear
            # and id(args[1]) in self.id_to_name
            and not skip_quant
            # and not (self.skip_layer_func)
        )

        # run function for each of the multitensors and return a multitensor
        if not quantize_linear:
            outputs = []
            for inp in grouped_args:
                inp = tensors_to_cuda(inp)
                cur_args, cur_kwargs = tree_unflatten(inp, spec)
                with torch._C.DisableTorchFunctionSubclass():
                    out = func(*cur_args, **cur_kwargs)
                outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
            return cls(outputs)

        # accumulate a running estimate of the GPTQ Hessian, H = (2/N) * sum_i x_i x_i^T,
        # over all calibration inputs seen for this linear op
        total_batches = 0
        H = 0
        for inp in grouped_args:
            inp = tensors_to_cuda(inp)
            cur_args, cur_kwargs = tree_unflatten(inp, spec)
            x = cur_args[0].float()
            shape = x.shape
            n = 1 if len(shape) == 2 else shape[0]
            H *= total_batches / (total_batches + n)
            total_batches += n
            x = (
                (2 / total_batches) ** (1 / 2) *
                x.reshape(-1, shape[-1]).t().float()
            )
            H += x.matmul(x.t())
        # quantize the weight using the accumulated Hessian, then rerun the linear
        # op against the dequantized weight, skipping quantization on the rerun
        W = args[1].to(H.device)
        Q, DQ, qparams = args[0].faster_quant(H, W.detach())

        new_out = cls.__torch_function__(
            func, types, (args[0], DQ, *args[2:]), kwargs, skip_quant=True
        )
        if args[0].debug:
            breakpoint()
        return new_out

        # NOTE: the code below is unreachable (the function returns above);
        # it is leftover from an earlier draft of this dispatch path.
        if func is torch.nn.functional.linear:
            inputs, weight, bias = (
                args[0],
                args[1],
                args[2] if len(args) > 2 else None,
            )
            if quantize_linear:
                cls.do_gptq(inputs, weight)
                return func(inputs, weight, bias)
        try:
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
        except Exception:
            print(f"ERR: subclass doesn't implement {func}")

class GenericGPTQRunner(fx.Interpreter):
    """
    This is a generic GPTQ runner that takes an existing model and applies GPTQ.
@@ -150,7 +271,7 @@ def __init__(
        }

        # trace model for one input
-        one_input = [multi.values[0].cpu() for multi in inputs]
+        one_input = tuple([multi.values[0].cpu() for multi in inputs])
        exported_model = torch._dynamo.export(
            model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
        )(*one_input)
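
For reference, a minimal standalone sketch of the argument-grouping step performed in GPTQMultiTensor.__torch_function__ above (the Multi class and the literal values here are illustrative only, not part of GPTQ.py): plain arguments are repeated, while each multi-valued argument contributes one value per replayed call.

class Multi:
    def __init__(self, values):
        self.values = values
    def count(self):
        return len(self.values)

flat_args = [1.0, Multi([10, 20, 30]), Multi([100, 200, 300])]

# the largest Multi decides how many times the wrapped op is replayed
n_calls = max(x.count() if isinstance(x, Multi) else 1 for x in flat_args)

# [a, Multi(b0,b1,b2), Multi(c0,c1,c2)] -> [(a,b0,c0), (a,b1,c1), (a,b2,c2)]
grouped = list(zip(*[x.values if isinstance(x, Multi) else [x] * n_calls for x in flat_args]))

print(grouped)  # [(1.0, 10, 100), (1.0, 20, 200), (1.0, 30, 300)]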
20 changes: 20 additions & 0 deletions run.sh
@@ -0,0 +1,20 @@
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf

# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
# echo "base"
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
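# GPTQ int4 quantization with 5 wikitext calibration samples, then wikitext eval limited to 5 samples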
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5

# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
# echo "quant good"

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5

# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
