diff --git a/GPTQ.py b/GPTQ.py
index e1279bd..d0d052d 100644
--- a/GPTQ.py
+++ b/GPTQ.py
@@ -129,6 +129,127 @@ def cuda(self):
         self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]
 
 
+class GPTQMultiTensor(torch.Tensor):
+    """
+    Tensor subclass that stores a list of tensor values (e.g. every calibration
+    input) and re-runs each op on all of them via __torch_function__; when the
+    op is torch.nn.functional.linear, it instead accumulates the Hessian over
+    the inputs and applies GPTQ to the weight.
+    """
+    # todo need default shape/dtype
+    @staticmethod
+    def __new__(cls, input, **kwargs):
+        kwargs["dtype"] = kwargs.get("dtype", input.dtype)
+        shape = kwargs.pop("shape", input.shape)
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+    def __init__(self, input, **kwargs):
+        self.values = []
+        self.append(input)
+        self.debug = False
+
+    def append(self, input):
+        if isinstance(input, (tuple, list)):
+            for inp in input:
+                self.values.append(inp)
+        elif isinstance(input, torch.Tensor):
+            self.values.append(input)
+
+    # def __add__(self, other):
+    #     for val in other.values:
+    #         self.append(val)
+
+    def count(self):
+        return len(self.values)
+
+    def cuda(self):
+        self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None, skip_quant=False):
+        def tensors_to_cuda(args):
+            new_args = []
+            for x in args:
+                new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
+            return new_args
+
+        kwargs = {} if kwargs is None else kwargs
+        # combine args and kwargs
+        flat_args, spec = tree_flatten((args, kwargs))
+        # move single tensors to cuda
+        flat_args = tensors_to_cuda(flat_args)
+        # size of biggest MultiTensor
+        multi_tensor_size = max(
+            [x.count() if isinstance(x, GPTQMultiTensor) else 1 for x in flat_args]
+        )
+        # convert [a, MultiTensor(b,b,b), MultiTensor(c,c,c)] => [a,b,c], [a,b,c], [a,b,c]
+        grouped_args = list(
+            zip(
+                *[x.values if isinstance(x, GPTQMultiTensor) else [x] * multi_tensor_size for x in flat_args]
+            )
+        )
+
+        quantize_linear = (
+            func is torch.nn.functional.linear
+            # and id(args[1]) in self.id_to_name
+            and not skip_quant
+            # and not (self.skip_layer_func)
+        )
+
+        # run function for each of the multitensors and return a multitensor
+        if not quantize_linear:
+            outputs = []
+            for inp in grouped_args:
+                inp = tensors_to_cuda(inp)
+                cur_args, cur_kwargs = tree_unflatten(inp, spec)
+                with torch._C.DisableTorchFunctionSubclass():
+                    out = func(*cur_args, **cur_kwargs)
+                outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
+            return cls(outputs)
+
+        # accumulate the Hessian over every calibration input, then quantize the weight
+        total_batches = 0
+        H = 0
+        for inp in grouped_args:
+            inp = tensors_to_cuda(inp)
+            cur_args, cur_kwargs = tree_unflatten(inp, spec)
+            x = cur_args[0].float()
+            shape = x.shape
+            n = 1 if len(shape) == 2 else shape[0]
+            H *= total_batches / (total_batches + n)
+            total_batches += n
+            x = (
+                (2 / total_batches) ** (1 / 2) *
+                x.reshape(-1, shape[-1]).t().float()
+            )
+            H += x.matmul(x.t())
+        W = args[1].to(H.device)
+        Q, DQ, qparams = args[0].faster_quant(H, W.detach())
+
+        # re-dispatch with the dequantized weight, skipping quantization this time
+        new_out = cls.__torch_function__(func, types, (args[0], DQ, *args[2:]), kwargs, skip_quant=True)
+        if args[0].debug:
+            breakpoint()
+        return new_out
+
+        # NOTE: unreachable leftover from an earlier draft; both paths above already return
+        if func is torch.nn.functional.linear:
+            inputs, weight, bias = (
+                args[0],
+                args[1],
+                args[2] if len(args) > 2 else None
+            )
+            if quantize_linear:
+                cls.do_gptq(input, weight)
+            return func(mat1, w_autoquant.weight, bias)
+        try:
+            with torch._C.DisableTorchFunctionSubclass():
+                return func(*args, **kwargs)
+        except:
+            print(f"ERR: subclass doesn't implement {func}")
+
+
 class GenericGPTQRunner(fx.Interpreter):
     """
     This is a generic GPTQ runner that takes an existing model and applies GPTQ.
@@ -150,7 +271,7 @@ def __init__(
         }
 
         # trace model for one input
-        one_input = [multi.values[0].cpu() for multi in inputs]
+        one_input = tuple([multi.values[0].cpu() for multi in inputs])
         exported_model = torch._dynamo.export(
             model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..2302c3e
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,20 @@
+export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
+# echo "base"
+export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
+python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
+# echo "quant good"
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+
+# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
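
For reference, below is a rough, self-contained Python sketch of the dispatch pattern the GPTQMultiTensor diff relies on: a wrapper tensor that holds several values and whose __torch_function__ re-runs each op once per held value, so one forward pass evaluates every calibration input. The class and variable names here (MultiTensor, groups, etc.) are illustrative only, and the GPTQ/Hessian branch is intentionally omitted; this is not the PR's exact class.

# minimal sketch of the MultiTensor broadcast-dispatch idea (hypothetical names)
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

class MultiTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, values, **kwargs):
        # use the first held tensor to give the wrapper a shape/dtype
        kwargs["dtype"] = kwargs.get("dtype", values[0].dtype)
        return torch.Tensor._make_wrapper_subclass(cls, values[0].shape, **kwargs)

    def __init__(self, values, **kwargs):
        self.values = list(values)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        flat, spec = tree_flatten((args, kwargs))
        size = max(len(x.values) if isinstance(x, MultiTensor) else 1 for x in flat)
        # [a, MultiTensor(b0, b1)] -> [(a, b0), (a, b1)]
        groups = zip(*[x.values if isinstance(x, MultiTensor) else [x] * size for x in flat])
        outputs = []
        for group in groups:
            cur_args, cur_kwargs = tree_unflatten(list(group), spec)
            with torch._C.DisableTorchFunctionSubclass():
                outputs.append(func(*cur_args, **cur_kwargs))
        return cls(outputs)

# usage: a single linear call runs once per held calibration input
w = torch.randn(8, 4)
inputs = MultiTensor([torch.randn(2, 4), torch.randn(2, 4)])
out = torch.nn.functional.linear(inputs, w)
print(len(out.values))  # 2 outputs, one per calibration input

In the actual diff, the linear branch would instead accumulate H from the grouped inputs and call faster_quant before re-dispatching with the dequantized weight.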