Update run_benchmark.py
Seun-Ajayi authored Aug 15, 2024
1 parent 16b4f7b commit 777d637
Showing 1 changed file with 38 additions and 26 deletions.
benchmark/run_benchmark.py: 64 changes (38 additions, 26 deletions)
@@ -35,16 +35,25 @@ def __init__(self, model_name: str, model_class, tokenizer_class, config_class,
         self.tokenizer = tokenizer_class.from_pretrained(self.model_name)
         self.hgf_model_class = model_class
         self.mlx_model_class = custom_model_class
 
+    def _set_device(self, device):
+        device_map = {
+            "mlx_cpu": mx.cpu,
+            "mlx_gpu": mx.gpu
+        }
+        mx.set_default_device(device_map[device])
+
     def load_hgf_model(self, device):
         return self.hgf_model_class.from_pretrained(self.model_name).to(device)
 
-    def load_mlx_model(self):
+    def load_mlx_model(self, device):
+        self._set_device(device)
         mlx_model = self.mlx_model_class(self.config)
         mlx_model.from_pretrained(self.model_name)
         return mlx_model
 
-    def prepare_mlx_model_input(self, input_text):
+    def prepare_mlx_model_input(self, input_text, device):
+        self._set_device(device)
         inputs_mlx = self.tokenizer(
             input_text, return_tensors="np", padding=True, truncation=True
         )
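For context, the new `_set_device` helper relies on MLX's global default device. A minimal sketch of that mechanism (an illustration, not part of this commit; assumes `mlx` is installed on an Apple-silicon machine):

```python
import mlx.core as mx

# mx.cpu and mx.gpu are mlx.core.Device values; set_default_device routes
# subsequent allocations and computation to the chosen device.
mx.set_default_device(mx.cpu)
a = mx.ones((2, 2))        # placed on the CPU

mx.set_default_device(mx.gpu)
b = mx.ones((2, 2))        # placed on the GPU
mx.eval(a, b)              # force MLX's lazy graph to execute
```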
@@ -58,21 +67,21 @@ def prepare_hgf_model_input(self, input_text, device):
         inputs_hgf = {key: v.to(device) for key, v in inputs_hgf.items()}
         return inputs_hgf
 
-    def get_mlx_model_inference(self, model, inputs_mlx):
+    def get_mlx_model_inference(self, model, inputs_mlx, device):
+        self._set_device(device)
+        outputs = model(**inputs_mlx)
         if model.config.architectures == ['BertForMaskedLM']:
-            output = model(**inputs_mlx)
-            _ = np.array(output.logits)
+            _ = np.array(outputs.logits)
         else:
-            output = model(**inputs_mlx)
-            _ = np.array(output.last_hidden_state)
-        return output
+            _ = np.array(outputs.last_hidden_state)
+        return outputs
 
-    def get_hgf_model_inference(self, model, inputs_hgf):
+    def get_hgf_model_inference(self, model, inputs_hgf, device=None):
         return model(**inputs_hgf)
 
 
 class Benchmark:
-    def __init__(self, models, backends, batch_sizes=[1, 16, 32], iterations=5, input_lengths=[50, 100, 200, 500, 1000]):
+    def __init__(self, models, backends, batch_sizes=[1, 16, 32], iterations=12, input_lengths=[50, 100, 200, 500, 1000]):
        self.models = models
        self.backends = backends
        self.batch_sizes=batch_sizes
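The `_ = np.array(outputs.logits)` lines matter because MLX evaluates lazily: without materializing the output, a timed call could return before any computation has actually run. A hypothetical stand-alone illustration (the variable names are mine, not the repo's):

```python
import time

import mlx.core as mx
import numpy as np

x = mx.random.normal((1024, 1024))
start = time.time()
y = x @ x              # builds a lazy compute graph; nothing has executed yet
_ = np.array(y)        # converting to NumPy forces the matmul to actually run
print(f"matmul took {(time.time() - start) * 1000:.2f} ms")
```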
@@ -83,14 +92,14 @@ def measure_inference_time(self, model, inputs, inference_func, backend):
         times = []
         for _ in range(self.iterations):
             start_time = time.time()
-            _ = inference_func(model, inputs)
+            _ = inference_func(model, inputs, backend)
             end_time = time.time()
             times.append((end_time - start_time) * 1000) # Convert to milliseconds
 
         if backend in ["cuda", "mps"]:
             torch.cuda.empty_cache()
 
-        return times[1:]
+        return sorted(times)[:10]
 
     def run_benchmark(self):
         detailed_results = []
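`sorted(times)[:10]` keeps the ten fastest of the twelve timed runs, a common way to discard warmup and scheduler noise. The same pattern in isolation (a sketch; `time_trimmed` is a hypothetical name, not a function in this repo):

```python
import time

def time_trimmed(fn, iterations=12, keep=10):
    """Time fn() `iterations` times and return the `keep` fastest runs in ms."""
    times = []
    for _ in range(iterations):
        start = time.time()
        fn()
        times.append((time.time() - start) * 1000)
    return sorted(times)[:keep]  # drop the slowest runs (warmup, GC, etc.)
```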
@@ -107,7 +116,6 @@ def run_benchmark(self):
                     custom_model_class=model_info['mlx_class']
                 )
 
-                mlx_model = model_instance.load_mlx_model()
                 for input_length in self.input_lengths:
                     for batch_size in self.batch_sizes:
                         input_text = generate_inputs(input_length, batch_size)
@@ -117,18 +125,12 @@ def run_benchmark(self):
                         tqdm.write(f"------ Running {model_info['name']} on {backend} with {input_length} chars and batch size {batch_size}... ------")
 
                         if backend == "mlx_cpu":
-                            mx.set_default_device(mx.cpu)
-                            inputs_mlx = model_instance.prepare_mlx_model_input(input_text)
+                            inputs_mlx = model_instance.prepare_mlx_model_input(input_text, backend)
+                            mlx_model = model_instance.load_mlx_model(backend)
                             mlx_inference_times = self.measure_inference_time(mlx_model, inputs_mlx, model_instance.get_mlx_model_inference, backend)
                         elif backend == "mlx_gpu":
-                            mx.set_default_device(mx.gpu)
-                            inputs_mlx = model_instance.prepare_mlx_model_input(input_text)
-                            mlx_inference_times = self.measure_inference_time(mlx_model, inputs_mlx, model_instance.get_mlx_model_inference, backend)
-                        elif backend == "mlx_gpu_compile":
-                            mx.set_default(mx.gpu)
-                            mx.compile()
-                            inputs_mlx = model_instance.prepare_mlx_model_input(input_text)
-                            mlx_model.compile()
+                            inputs_mlx = model_instance.prepare_mlx_model_input(input_text, backend)
+                            mlx_model = model_instance.load_mlx_model(backend)
                             mlx_inference_times = self.measure_inference_time(mlx_model, inputs_mlx, model_instance.get_mlx_model_inference, backend)
                         elif backend == "torch_cpu":
                             device = torch.device("cpu")
@@ -240,7 +242,7 @@ def save_results(
         model_info = result['model']
         average_times[model_info][result['backend']] = result['average_time']
 
-    with open("benchmark/benchmark_results.md", "w") as f:
+    with open("benchmark/test3_mlx_benchmark_results.md", "w") as f:
         f.write("## Detailed Benchmark\n")
         f.write("Detailed runtime benchmark of model inferences, measured in milliseconds.\n\n")
@@ -272,7 +274,7 @@ def save_results(
 
 def main():
     parser = argparse.ArgumentParser(description="Benchmark model inferences on different backends.")
-    parser.add_argument("--backends", nargs="+", default=["mlx_cpu", "mlx_gpu", "torch_cpu"],
+    parser.add_argument("--backends", nargs="+", default=["mlx_gpu"],
                         help="List of backends to benchmark on. E.g., --backends mlx_cpu mlx_gpu torch_cpu torch_cuda torch_mps")
    parser.add_argument('--iterations', type=int, default=11, help='Number of runs for each benchmark')
    parser.add_argument("--input_lengths", nargs="+", type=int, default=[50, 100, 200, 500], help="List of input character lengths.")
@@ -285,6 +287,7 @@ def main():
         assert torch.backends.mps.is_available(), "MPS backend not available."
     if "torch_cuda" in args.backends:
         assert torch.cuda.is_available(), "CUDA device not found."
 
+
 
     models = [
@@ -295,6 +298,13 @@ def main():
             'config_class': BertConfig,
             'mlx_class': MlxBertForMaskedLM,
         },
+        {
+            'name': 'bert-large-uncased',
+            'hgf_class': BertForMaskedLM,
+            'tokenizer_class': BertTokenizer,
+            'config_class': BertConfig,
+            'mlx_class': MlxBertForMaskedLM,
+        },
         {
             'name': 'roberta-base',
             'hgf_class': RobertaModel,
@@ -308,7 +318,7 @@ def main():
             'tokenizer_class': XLMRobertaTokenizer,
             'config_class': XLMRobertaConfig,
             'mlx_class': MlxXLMRobertaModel,
-        }
+        },
     ]
 
     benchmark = Benchmark(
@@ -322,3 +332,5 @@ def main():
 
 if __name__ == "__main__":
     main()
+
+# You can edit line 277 to compute the runtime on `mlx_cpu`
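A hypothetical invocation of the script after this change, using only the flags defined in `main()` above (path assumed relative to the repository root):

```bash
python benchmark/run_benchmark.py \
    --backends mlx_cpu mlx_gpu torch_cpu \
    --iterations 12 \
    --input_lengths 50 100 200 500
```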