Dynamic Float8 benchmarking llama (#1017)
jainapurva authored Oct 7, 2024
1 parent 52d27a1 commit 92dd5f5
Showing 5 changed files with 82 additions and 12 deletions.
12 changes: 9 additions & 3 deletions torchao/_models/llama/benchmarks.sh
@@ -9,8 +9,6 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt



export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
@@ -19,6 +17,14 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt

export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
# Runs on H100, float8 is not supported on CUDA arch < 8.9
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-tensor --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-row --write_result benchmark_results.txt

# OTHER BENCHMARKS

@@ -58,4 +64,4 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
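
The comment above notes that the float8 runs require an H100-class GPU. A minimal sketch (not part of this commit) for checking that requirement before launching the float8 lines; it assumes torch is installed with a CUDA device visible:

```python
# Guard for the float8 benchmark lines: the float8 paths need
# compute capability >= 8.9 (Ada / Hopper, e.g. H100).
import torch

major, minor = torch.cuda.get_device_capability()
if (major, minor) < (8, 9):
    raise SystemExit(f"float8 benchmarks need SM 8.9+, found SM {major}.{minor}")
```
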
34 changes: 28 additions & 6 deletions torchao/_models/llama/eval.py
@@ -20,13 +20,16 @@
fpx_weight_only,
uintx_weight_only,
unwrap_tensor_subclass,
float8_weight_only,
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
)
from torchao.quantization.observer import PerRow, PerTensor
from torchao._models._eval import TransformerEvalWrapper, InputRecorder
from torchao._models.llama.model import prepare_inputs_for_model

from tokenizer import get_tokenizer
import time
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao._models.llama.model import prepare_inputs_for_model, TransformerBlock
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

def run_evaluation(
@@ -55,19 +58,16 @@ def run_evaluation(
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), str(tokenizer_path)
# Load Model and Tokenizer

print("Loading model ...")
t0 = time.time()
model = _load_model(checkpoint_path, "cpu", precision)

if max_length is None:
max_length = model.config.block_size

device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)


if quantization:
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
@@ -100,6 +100,9 @@
from torchao.dtypes import MarlinSparseLayoutType
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
if "int4wo" in quantization and "gptq" in quantization:
# avoid circular imports
from torchao._models._eval import InputRecorder
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
groupsize=int(quantization.split("-")[-2])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
assert precision==torch.bfloat16, f"{quantization} requires precision of bfloat16 but got {precision}"
@@ -122,9 +125,24 @@
else:
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(model)
if "float8wo" in quantization:
quantize_(model, float8_weight_only())
if "float8dq" in quantization:
granularity = str(quantization.split("-")[-1])
if granularity=="tensor":
granularity = PerTensor()
elif granularity=="row":
granularity = PerRow()
else:
if granularity=="float8dq":
granularity = PerTensor()
else:
raise ValueError(f"Unknown granularity {granularity}")
quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
if "autoround" in quantization:
from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
from transformers import AutoTokenizer
from torchao._models.llama.model import TransformerBlock

_tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent)
# parse args from quantization string:
@@ -182,6 +200,9 @@ def run_evaluation(
if compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)
with torch.no_grad():
print("Running evaluation ...")
# avoid circular imports
from torchao._models._eval import TransformerEvalWrapper
TransformerEvalWrapper(
model=model.to(device),
tokenizer=tokenizer,
Expand Down Expand Up @@ -209,7 +230,8 @@ def run_evaluation(
"Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, "
"int4wo-<groupsize>-gptq, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, "
"uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, "
"autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>"
"autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>, "
"float8wo, float8dq, float8saq"
),
)
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
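The float8 additions above parse the scaling granularity out of the `--quantization` string. A standalone sketch of that mapping (the helper name `parse_float8dq_granularity` is introduced here for illustration; the logic mirrors the branch added to eval.py):

```python
from torchao.quantization.observer import PerRow, PerTensor

def parse_float8dq_granularity(quantization: str):
    # "float8dq-tensor" -> PerTensor, "float8dq-row" -> PerRow,
    # bare "float8dq" -> PerTensor, anything else is rejected.
    suffix = quantization.split("-")[-1]
    if suffix == "tensor":
        return PerTensor()
    if suffix == "row":
        return PerRow()
    if suffix == "float8dq":
        return PerTensor()
    raise ValueError(f"Unknown granularity {suffix}")

assert isinstance(parse_float8dq_granularity("float8dq-row"), PerRow)
```
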
12 changes: 10 additions & 2 deletions torchao/_models/llama/evals.sh
@@ -11,5 +11,13 @@ python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quanti
export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'mmlu' 'truthfulqa_mc2'
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'winogrande' 'arc_challenge'
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8wo
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-tensor
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-row

# Testing on additional tasks
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --tasks 'winogrande' 'arc_challenge'
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --tasks 'mmlu' 'truthfulqa_mc2'
16 changes: 15 additions & 1 deletion torchao/_models/llama/generate.py
@@ -210,8 +210,11 @@ def main(
fpx_weight_only,
uintx_weight_only,
autoquant,
unwrap_tensor_subclass
unwrap_tensor_subclass,
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
from torchao.quantization.observer import PerTensor, PerRow
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
@@ -243,6 +246,17 @@ def main(
dtype = _NBITS_TO_DTYPE[nbits]
group_size = int(_quant_args[2])
quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
if "float8wo" in quantization:
quantize_(model, float8_weight_only())
if "float8dq" in quantization:
granularity = str(quantization.split("-")[-1])
if granularity=="tensor":
granularity = PerTensor()
elif granularity=="row":
granularity = PerRow()
else:
granularity = PerTensor()
quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
if "autoquant" in quantization:
if "autoquant-int4" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
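The two new generate.py options apply different transforms: `float8wo` keeps activations in high precision and converts only the weights to float8, while `float8dq` also quantizes activations dynamically to float8. A minimal sketch of applying each (`model_wo` / `model_dq` are placeholder names, and the import path for `float8_weight_only` is assumed to match the other float8 APIs):

```python
from torchao.quantization import (
    quantize_,
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)
from torchao.quantization.observer import PerTensor

quantize_(model_wo, float8_weight_only())  # weight-only float8
quantize_(model_dq, float8_dynamic_activation_float8_weight(granularity=PerTensor()))  # dynamic float8 activations + float8 weights
```
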
20 changes: 20 additions & 0 deletions torchao/quantization/README.md
@@ -21,6 +21,17 @@ Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GP
| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 |
| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 |

Benchmarks and evaluation for the model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation uses the lm_eval library for tasks and data.

| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
| Llama-3.1-8B | Base (bfloat16) | 7.54 | 126.90 | 1904.75 | 16.75 | 15.01 |
| | int8wo | 7.56 | 198.85 | 1495.41 | 11.05 | 7.52 |
| | int4wo-64 | 8.44 | 241.39 | 1019.14 | 7.08 | 4.22 |
| | float8wo | 7.60 | 178.46 | 1339.93 | 12.09 | 7.51 |
| | float8dq (PerTensor) | 7.62 | 116.40 | 873.58 | 11.14 | 7.51 |
| | float8dq (PerRow) | 7.61 | 154.63 | 1161.47 | 11.14 | 7.51 |
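
The throughput and bandwidth columns are consistent with decoding being weight-bandwidth bound: the Memory Bandwidth values match tokens/s × model size, since each generated token streams the full set of weights once. A quick check against the base row:

```python
tok_per_s, model_gb = 126.90, 15.01   # base bfloat16 row above
print(tok_per_s * model_gb)           # ~1904.8 GB/s vs. the reported 1904.75
```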

Note: Int8 dynamic quantization works best on compute-bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast), whereas Llama with batch size 1 tends to be memory bound, hence the relatively low performance.

For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3bd8c674f2123af232a0231b5e38ddafa756a8/torchao/dtypes/aqt.py#L526) of `torch.ops.aten._weight_int4pack_mm` to bitpack into a layout optimized for tensor cores
@@ -121,6 +132,15 @@ from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtenso
change_linear_weights_to_int8_dqtensors(model)
```

#### A8W8 Float8 Dynamic Quantization

```python
# for torch 2.4+
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerTensor
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
```
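
Per-row scaling (the `float8dq (PerRow)` row in the benchmark table above) uses the same API with `PerRow`; a minimal variant of the example (per-row float8 is typically paired with bfloat16 weights):

```python
# per-row scaling variant of the example above
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```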

#### A16W6 Floating Point WeightOnly Quantization

```python
