Update inference benchmarking script (#55)
abhi-mosaic authored May 6, 2023
1 parent 3d51eaf commit 3959eac
Showing 4 changed files with 60 additions and 85 deletions.
4 changes: 2 additions & 2 deletions scripts/inference/benchmarking/README.md
@@ -2,13 +2,13 @@
 This folder provides scripts for benchmarking the inference performance of deep learning models. Currently, we support benchmarking with Deepspeed and Huggingface generate.

 ## Scripts
-The repository includes the benchmark.py script, along with associated `.yaml files,` to run benchmarking. The script takes a `.yaml` file as input and outputs the latency (in seconds) and tokens per second for each run. We average over `num_runs=5`, which is defined in the `.yaml` file. Additionally, we iterate over various `batch_sizes`, `input_lengths`, and `output_lengths` to produce varying throughput metrics.
+The repository includes the benchmark.py script, along with associated `.yaml files,` to run benchmarking. The script takes a `.yaml` file as input and outputs the latency (in seconds) and tokens per second for each run. We average over `num_batches=5`, which is defined in the `.yaml` file. Additionally, we iterate over various `batch_sizes`, `input_lengths`, and `output_lengths` to produce varying throughput metrics.

 ## Usage
 To use the `benchmark.py` script, you need to provide a `.yaml` file that specifies the model configuration and other parameters such as the path to the model checkpoint and the input data. You can modify the default `.yaml` files provided in the repository or create your own `.yaml` file.

 To run the benchmarking script, use the following command:

-`python benchmark.py config.yaml`
+`python benchmark.py yamls/1b.yaml`

 To run the scripts on [The MosaicML platform](https://www.mosaicml.com/blog/mosaicml-cloud-demo) we've also included scripts and associated `.yaml files` in the `mcloud` folder.
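
For orientation, here is a minimal sketch of how a config such as `yamls/1b.yaml` can be loaded and overridden from the command line with OmegaConf. The `__main__` block of `benchmark.py` is truncated in the diff below, so the exact merge logic here is an assumption rather than the script's verbatim code.

```python
# Hypothetical sketch: load the benchmark YAML and apply optional key=value
# CLI overrides (e.g. `python benchmark.py yamls/1b.yaml num_batches=10`).
import sys

from omegaconf import OmegaConf as om

if __name__ == '__main__':
    yaml_path, args_list = sys.argv[1], sys.argv[2:]
    with open(yaml_path) as f:
        yaml_cfg = om.load(f)          # parse the YAML into a DictConfig
    cli_cfg = om.from_cli(args_list)   # turn extra CLI args into a config
    cfg = om.merge(yaml_cfg, cli_cfg)  # CLI values take precedence over YAML
    print(om.to_yaml(cfg))             # show the resolved benchmark config
```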
119 changes: 47 additions & 72 deletions scripts/inference/benchmarking/benchmark.py
@@ -1,43 +1,52 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

-import contextlib
 import sys
 import time
+from contextlib import nullcontext

-import numpy as np
 import torch
 # You can use this to load the model weights
 from omegaconf import OmegaConf as om

 from llmfoundry import COMPOSER_MODEL_REGISTRY


-def get_precision(precision):
-    if precision == 'fp32':
+def get_dtype(dtype):
+    if dtype == 'fp32':
         return torch.float32
-    elif precision == 'fp16':
+    elif dtype == 'fp16':
         return torch.float16
-    elif precision == 'bf16':
+    elif dtype == 'bf16':
         return torch.bfloat16
     else:
         raise NotImplementedError(
-            f'Precision of type {precision} is not supported. '
-            f'We only support fp32, amp_fp16, and amp_bf16 currently')
+            f'dtype {dtype} is not supported. '
+            f'We only support fp32, fp16, and bf16 currently')


-def compare_precision(precision, param_dtype):
-    if precision != param_dtype:
+def compare_dtype(dtype, param_dtype):
+    if dtype != param_dtype:
         raise ValueError(
-            f'Precision type is: {precision} but model dtype is: {param_dtype}. '
-            f"The expected precision and model precision don't match.")
+            f'dtype type is: {dtype} but model dtype is: {param_dtype}. '
+            f"The expected dtype and model dtype don't match.")


 def main(config):
-    model_dtype = get_precision(config.model_dtype)
-    autocast_precision = None
-    if config.autocast_precision is not None:
-        autocast_precision = get_precision(config.autocast_precision)
+    if config.device is not None:
+        device = config.device
+    else:
+        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+    model_dtype = get_dtype(config.model_dtype)
+    print(f'Using device={device} and dtype={model_dtype}...')

+    if config.autocast_dtype is not None:
+        autocast_dtype = get_dtype(config.autocast_dtype)
+        autocast_context = torch.autocast(device, autocast_dtype)
+        print(f'Using autocast with dtype={autocast_dtype}...')
+    else:
+        autocast_context = nullcontext()
+        print('NOT using autocast...')

     inference_config = {
         'replace_with_kernel_inject': True,
@@ -51,9 +60,7 @@ def main(config):

     composer_model = COMPOSER_MODEL_REGISTRY[config.model.name](
         config.model, config.tokenizer)
-
     model = composer_model.model
-
     model.eval()

     if config.use_deepspeed:
@@ -62,90 +69,58 @@

         # Checking if deepspeed casts dtypes correctly
         for _, p in model.named_parameters():
-            compare_precision(model_dtype, p.dtype)
+            compare_dtype(model_dtype, p.dtype)
             break
     else:
-        model.to(torch.cuda.current_device())
-        model.to(model_dtype)
+        model.to(device=device, dtype=model_dtype)

     n_params = sum(p.numel() for p in model.parameters())
     print('n_params is: ', n_params)

-    print('name, latency (s), tokens / s, output token time (ms)')
+    print(
+        'name, latency (s), throughput (tokens/s), latency_per_sequence_output_token (ms)'
+    )
     print('=' * 75)

-    stats = []
     for batch_size in config.batch_sizes:
         for input_length in config.input_lengths:
             for output_length in config.output_lengths:
-                times = []

-                batch = torch.randint(
-                    0,
-                    config.model.vocab_size - 1,
-                    size=(
-                        batch_size,
-                        input_length)).to(f'cuda:{torch.cuda.current_device()}')
+                batch = torch.randint(0,
+                                      config.model.vocab_size - 1,
+                                      size=(batch_size,
+                                            input_length)).to(device)

                 # We're just going to have generate eos, padding tokens be
                 # ignored by HF generate
                 batch = batch.to(torch.long)
                 attention_mask = torch.ones_like(batch)

                 torch.cuda.synchronize()

-                for i in range(config.num_runs + 1):
-                    start_time = time.time()
+                start_time = 0
+                for i in range(config.num_batches + config.num_warmup_batches):
+                    if i == config.num_warmup_batches:
+                        torch.cuda.synchronize()
+                        start_time = time.time()
                     with torch.no_grad():
-                        precision_context = contextlib.nullcontext()
-                        if autocast_precision is not None and autocast_precision in [
-                                'fp16', 'bf16'
-                        ]:
-                            precision_context = torch.cuda.amp.autocast(
-                                True, dtype=autocast_precision)

-                        with precision_context:
+                        with autocast_context:
                             model.generate(batch,
                                            max_new_tokens=output_length,
-                                           use_cache=True,
+                                           use_cache=config.use_cache,
                                            attention_mask=attention_mask,
                                            eos_token_id=None,
                                            pad_token_id=None)

-                    torch.cuda.synchronize()

-                    # We noticed there sometimes might be a small bit of startup time
-                    # so we only start to benchmark after some number of batches
-                    if i >= config.num_warmup_batches:
-                        times.append(time.time() - start_time)
+                torch.cuda.synchronize()
+                mean_time = (time.time() - start_time) / config.num_batches

                 num_output_tokens = output_length * batch_size
-                mean_time = np.mean(times)
-                tokens_per_second = num_output_tokens / float(mean_time)
-                ms_per_seq_output_token = float(
-                    mean_time) * 1000 / num_output_tokens

-                result = (
-                    f'{config.benchmark_name}_{batch_size}_{input_length}_{output_length}',
-                    f'{mean_time:.3f}', f'{tokens_per_second:.3f}',
-                    f'{ms_per_seq_output_token:.3f}')

-                run_name, latency, tokens_per_second, ms_per_seq_output_token = result
+                tokens_per_second = num_output_tokens / mean_time
+                ms_per_seq_output_token = mean_time * 1000 / output_length

+                run_name = f'{config.benchmark_name}_{batch_size}_{input_length}_{output_length}'
                 print(
-                    f'{run_name}, {latency}, {tokens_per_second}, {ms_per_seq_output_token}'
+                    f'{run_name}, {mean_time:.3f}, {tokens_per_second:.3f}, {ms_per_seq_output_token:.3f}'
                 )

-                stats.append(result)

-    print('=' * 75)
-    print('name, latency (s), tokens / s, output token time (ms)')
-    for val in stats:
-        run_name, latency, tokens_per_second, ms_per_seq_output_token = val
-        print(
-            f'{run_name}, latency (s) {latency}, tokens per second {tokens_per_second}, output token time (ms) {ms_per_seq_output_token}'
-        )


 if __name__ == '__main__':
     yaml_path, args_list = sys.argv[1], sys.argv[2:]
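
To summarize the measurement pattern introduced above: warmup batches run first and are excluded from timing, the clock starts after a CUDA synchronize once warmup ends, and the mean latency over the timed batches yields tokens/s and ms per output token. The sketch below condenses that pattern; the `timed_run` helper is hypothetical, a dummy workload stands in for `model.generate`, and the CPU-fallback guards are an addition for portability that the script itself does not have.

```python
# Condensed sketch of the warmup-then-time loop used by the updated script.
import time

import torch


def timed_run(workload, num_batches: int = 5, num_warmup_batches: int = 3) -> float:
    start_time = 0.0
    for i in range(num_batches + num_warmup_batches):
        if i == num_warmup_batches:
            # Warmup is over: synchronize once, then start the clock.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.time()
        with torch.no_grad():
            workload()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.time() - start_time) / num_batches  # mean seconds per timed batch


# Dummy workload; in the real script this is model.generate(...).
batch_size, output_length = 8, 8
mean_time = timed_run(lambda: torch.randn(512, 512) @ torch.randn(512, 512))
tokens_per_second = batch_size * output_length / mean_time
ms_per_seq_output_token = mean_time * 1000 / output_length
print(f'{mean_time:.4f} s, {tokens_per_second:.1f} tokens/s, '
      f'{ms_per_seq_output_token:.4f} ms per output token')
```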
10 changes: 5 additions & 5 deletions scripts/inference/benchmarking/yamls/1b.yaml
@@ -8,7 +8,6 @@ tokenizer:
   name: ${tokenizer_name}
   kwargs:
     model_max_length: ${max_seq_len}
-    non_eos_token_id: 17

 model:
   name: mpt_causal_lm
@@ -27,14 +26,15 @@ model:
   attn_config:
     attn_impl: triton

-autocast_precision: bf16
+device: null
 model_dtype: bf16
+autocast_dtype: null
-use_deepspeed: false

 batch_sizes: [1, 2, 4, 8, 16, 32, 64]
 input_lengths: [128]
 output_lengths: [8]
-num_runs: 5
+use_cache: true

+num_batches: 5
+num_warmup_batches: 3

+use_deepspeed: false
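
This YAML replaces the old `autocast_precision` key with `device` and `autocast_dtype` alongside `model_dtype`. A rough sketch of how these settings map onto torch objects, following the `get_dtype` and autocast logic in `benchmark.py` above; the `resolve_runtime` helper is hypothetical and simplified (it strips the device index before calling `torch.autocast`, which expects a bare device type).

```python
# Simplified mapping from the new YAML keys to torch objects (assumed helper,
# not part of benchmark.py).
from contextlib import nullcontext

import torch


def resolve_runtime(device=None, model_dtype='bf16', autocast_dtype=None):
    dtypes = {'fp32': torch.float32, 'fp16': torch.float16, 'bf16': torch.bfloat16}
    # device: null in the YAML means "cuda:0 if available, else cpu".
    device = device or ('cuda:0' if torch.cuda.is_available() else 'cpu')
    if autocast_dtype is not None:
        # Autocast runs compute in autocast_dtype while weights stay in model_dtype.
        ctx = torch.autocast(device.split(':')[0], dtypes[autocast_dtype])
    else:
        # autocast_dtype: null -> run directly in model_dtype, no autocast.
        ctx = nullcontext()
    return device, dtypes[model_dtype], ctx


device, dtype, autocast_ctx = resolve_runtime()
print(device, dtype, type(autocast_ctx).__name__)
```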
12 changes: 6 additions & 6 deletions scripts/inference/benchmarking/yamls/7b.yaml
@@ -8,7 +8,6 @@ tokenizer:
   name: ${tokenizer_name}
   kwargs:
     model_max_length: ${max_seq_len}
-    non_eos_token_id: 17

 model:
   name: mpt_causal_lm
@@ -27,14 +26,15 @@ model:
   attn_config:
     attn_impl: triton

-autocast_precision: bf16
-model_dtype: fp32
+device: null
+model_dtype: bf16
+autocast_dtype: null
-use_deepspeed: false

 batch_sizes: [1, 2, 4, 8, 16, 32, 64]
 input_lengths: [128]
 output_lengths: [8]
-num_runs: 5
+use_cache: true

+num_batches: 5
+num_warmup_batches: 3

+use_deepspeed: false
