From c3884daff6317380aa0255a46d8c7e69354cbf97 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 7 May 2024 13:39:24 +0200 Subject: [PATCH] update llm perf --- ...yaml => update_llm_perf_cuda_pytorch.yaml} | 10 +- .../update_open_llm_leaderboard.yaml | 36 ++++ Makefile | 4 +- llm_perf/constants.py | 36 ---- ...rch.py => update_llm_perf_cuda_pytorch.py} | 85 +++++---- llm_perf/update_open_llm_leaderboard.py | 42 +++++ llm_perf/utils.py | 175 +++++++++++++++--- 7 files changed, 284 insertions(+), 104 deletions(-) rename .github/workflows/{llm_perf_cuda_pytorch.yaml => update_llm_perf_cuda_pytorch.yaml} (89%) create mode 100644 .github/workflows/update_open_llm_leaderboard.yaml delete mode 100644 llm_perf/constants.py rename llm_perf/{benchmark_cuda_pytorch.py => update_llm_perf_cuda_pytorch.py} (61%) create mode 100644 llm_perf/update_open_llm_leaderboard.py diff --git a/.github/workflows/llm_perf_cuda_pytorch.yaml b/.github/workflows/update_llm_perf_cuda_pytorch.yaml similarity index 89% rename from .github/workflows/llm_perf_cuda_pytorch.yaml rename to .github/workflows/update_llm_perf_cuda_pytorch.yaml index 5b849f56..64bbe4d9 100644 --- a/.github/workflows/llm_perf_cuda_pytorch.yaml +++ b/.github/workflows/update_llm_perf_cuda_pytorch.yaml @@ -1,9 +1,11 @@ -name: LLM Perf Benchmarks - CUDA PyTorch +name: Update LLM Perf Benchmarks - CUDA PyTorch on: workflow_dispatch: + push: + branches: + - update-llm-perf schedule: - # Every day at 00:00 UTC - cron: "0 0 * * *" concurrency: @@ -53,7 +55,7 @@ jobs: --volume ${{ github.workspace }}:/workspace --workdir /workspace run: | - pip install packaging && pip install flash-attn einops scipy auto-gptq optimum bitsandbytes autoawq + pip install packaging && pip install flash-attn einops scipy auto-gptq optimum bitsandbytes autoawq codecarbon pip install -U transformers huggingface_hub[hf_transfer] - pip install -e .[codecarbon] + pip install -e . python llm_perf/benchmark_cuda_pytorch.py diff --git a/.github/workflows/update_open_llm_leaderboard.yaml b/.github/workflows/update_open_llm_leaderboard.yaml new file mode 100644 index 00000000..606916e3 --- /dev/null +++ b/.github/workflows/update_open_llm_leaderboard.yaml @@ -0,0 +1,36 @@ +name: Update Open LLM Leaderboard + +on: + workflow_dispatch: + push: + branches: + - update-llm-perf + schedule: + - cron: "0 0 * * *" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + update_open_llm_leaderboard: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: "3.10" + + - name: Install requirements + run: | + pip install --upgrade pip + pip install pandas huggingface-hub + + - name: Update Open LLM Leaderboard + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + python scripts/update_open_llm_leaderboard.py diff --git a/Makefile b/Makefile index 2ab26568..b80d8144 100644 --- a/Makefile +++ b/Makefile @@ -173,9 +173,9 @@ test_cli_rocm_pytorch_single_gpu: # llm-perf install_llm_perf_cuda_pytorch: - pip install packaging && pip install flash-attn einops scipy auto-gptq optimum bitsandbytes autoawq + pip install packaging && pip install flash-attn einops scipy auto-gptq optimum bitsandbytes autoawq codecarbon pip install -U transformers huggingface_hub[hf_transfer] - pip install -e .[codecarbon] + pip install -e . 
run_llm_perf_cuda_pytorch_unquantized: SUBSET=unquantized python llm_perf/benchmark_cuda_pytorch.py diff --git a/llm_perf/constants.py b/llm_perf/constants.py deleted file mode 100644 index 13d4973d..00000000 --- a/llm_perf/constants.py +++ /dev/null @@ -1,36 +0,0 @@ -import pandas as pd - -INPUT_SHAPES = {"batch_size": 1, "sequence_length": 256} -GENERATE_KWARGS = {"max_new_tokens": 64, "min_new_tokens": 64} - -OPEN_LLM_DATAFRAME = pd.read_csv("hf://datasets/optimum/llm-perf-dataset/open-llm.csv") -PRETRAINED_MODELS_LIST = OPEN_LLM_DATAFRAME.sort_values("Size", ascending=True)["Model"].tolist() - -CANONICAL_ORGANIZATIONS = [ - # big companies - *["google", "facebook", "meta", "meta-llama", "microsoft", "Intel", "TencentARC", "Salesforce"], - # collectives - *["EleutherAI", "tiiuae", "NousResearch", "Open-Orca"], - # HF related - ["bigcode", "HuggingFaceH4"], - # community members - ["teknium"], - # startups - *[ - "mistral-community", - "openai-community", - "togethercomputer", - "stabilityai", - "CohereForAI", - "databricks", - "mistralai", - "internlm", - "Upstage", - "xai-org", - "Phind", - "01-ai", - "Deci", - "Qwen", - ], -] -CANONICAL_MODELS_LIST = [model for model in PRETRAINED_MODELS_LIST if model.split("/")[0] in CANONICAL_ORGANIZATIONS] diff --git a/llm_perf/benchmark_cuda_pytorch.py b/llm_perf/update_llm_perf_cuda_pytorch.py similarity index 61% rename from llm_perf/benchmark_cuda_pytorch.py rename to llm_perf/update_llm_perf_cuda_pytorch.py index 64a40c60..6aaab05c 100644 --- a/llm_perf/benchmark_cuda_pytorch.py +++ b/llm_perf/update_llm_perf_cuda_pytorch.py @@ -2,21 +2,24 @@ from itertools import product from logging import getLogger -from llm_perf.constants import CANONICAL_MODELS_LIST, GENERATE_KWARGS, INPUT_SHAPES, PRETRAINED_MODELS_LIST -from llm_perf.utils import common_errors_reporter, is_experiment_conducted, is_experiment_not_supported -from optimum_benchmark.backends.pytorch.config import PyTorchConfig -from optimum_benchmark.benchmarks.inference.config import InferenceConfig -from optimum_benchmark.experiment import ExperimentConfig, launch -from optimum_benchmark.launchers.process.config import ProcessConfig +from llm_perf.utils import ( + CANONICAL_PRETRAINED_OPEN_LLM_LIST, + GENERATE_KWARGS, + INPUT_SHAPES, + OPEN_LLM_LIST, + PRETRAINED_OPEN_LLM_LIST, + errors_reporter, + is_benchmark_conducted, + is_benchmark_supported, +) +from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig from optimum_benchmark.logging_utils import setup_logging CWD = os.getcwd() MACHINE = os.getenv("MACHINE", "1xA100") SUBSET = os.getenv("SUBSET", "unquantized") -CANONICAL_MODELS_ONLY = os.getenv("CANONICAL_MODELS_ONLY", "1") == "1" PUSH_REPO_ID = f"optimum-benchmark/llm-perf-pytorch-cuda-{SUBSET}-{MACHINE}" - ATTENTION_COFIGS = ["eager", "sdpa", "flash_attention_2"] if SUBSET == "unquantized": WEIGHTS_CONFIGS = { @@ -79,25 +82,26 @@ } -setup_logging() LOGGER = getLogger("llm-perf-backend") +LOGGER.info(f"len(OPEN_LLM_LIST): {len(OPEN_LLM_LIST)}") +LOGGER.info(f"len(PRETRAINED_OPEN_LLM_LIST): {len(PRETRAINED_OPEN_LLM_LIST)}") +LOGGER.info(f"len(CANONICAL_PRETRAINED_OPEN_LLM_LIST): {len(CANONICAL_PRETRAINED_OPEN_LLM_LIST)}") def benchmark_cuda_pytorch(model, attn_implementation, weights_config): + benchmark_name = f"{weights_config}-{attn_implementation}" + subfolder = f"{benchmark_name}/{model.replace('/', '--')}" + torch_dtype = WEIGHTS_CONFIGS[weights_config]["torch_dtype"] quant_scheme = WEIGHTS_CONFIGS[weights_config]["quant_scheme"] 
quant_config = WEIGHTS_CONFIGS[weights_config]["quant_config"] - if is_experiment_not_supported(torch_dtype, attn_implementation): - LOGGER.info(f"Skipping experiment with model {model} since it is not supported") + if not is_benchmark_supported(weights_config, attn_implementation): + LOGGER.info(f"Skipping benchmark {benchmark_name} with model {model} since it is not supported") return - launcher_config = ProcessConfig( - start_method="spawn", - device_isolation=True, - device_isolation_action="error", - ) - benchmark_config = InferenceConfig( + launcher_config = ProcessConfig(device_isolation=True, device_isolation_action="kill") + scenario_config = InferenceConfig( memory=True, energy=True, latency=True, @@ -110,7 +114,7 @@ def benchmark_cuda_pytorch(model, attn_implementation, weights_config): backend_config = PyTorchConfig( model=model, device="cuda", - device_ids="0", + device_ids="4", no_weights=True, library="transformers", task="text-generation", @@ -120,38 +124,41 @@ def benchmark_cuda_pytorch(model, attn_implementation, weights_config): attn_implementation=attn_implementation, ) - experiment_name = f"{weights_config}-{attn_implementation}" - subfolder = f"{experiment_name}/{model.replace('/', '--')}" - - experiment_config = ExperimentConfig( - experiment_name=experiment_name, - benchmark=benchmark_config, - launcher=launcher_config, - backend=backend_config, + benchmark_config = BenchmarkConfig( + name=benchmark_name, scenario=scenario_config, launcher=launcher_config, backend=backend_config ) - if is_experiment_conducted(experiment_config, PUSH_REPO_ID, subfolder): - LOGGER.info(f"Skipping experiment {experiment_name} with model {model} since it was already conducted") + if is_benchmark_conducted(benchmark_config, PUSH_REPO_ID, subfolder): + LOGGER.info(f"Skipping benchmark {benchmark_name} with model {model} since it was already conducted") return - experiment_config.push_to_hub(subfolder=subfolder, repo_id=PUSH_REPO_ID, private=True) + benchmark_config.push_to_hub(subfolder=subfolder, repo_id=PUSH_REPO_ID, private=True) try: - benchmark_report = launch(experiment_config) + LOGGER.info(f"Running benchmark {benchmark_name} with model {model}") + benchmark_report = Benchmark.launch(benchmark_config) benchmark_report.push_to_hub(subfolder=subfolder, repo_id=PUSH_REPO_ID, private=True) except Exception as error: - os.chdir(CWD) # TODO: figure our why this is happening - LOGGER.error(f"Experiment {experiment_name} failed with model {model}") - common_errors_reporter(error, LOGGER, subfolder, PUSH_REPO_ID) + LOGGER.error(f"Benchmark {benchmark_name} failed with model {model}") + valid_error, benchmark_report = errors_reporter(error) + LOGGER.error(benchmark_report.error, exc_info=True) + if valid_error: + benchmark_report.push_to_hub(subfolder=subfolder, repo_id=PUSH_REPO_ID, private=True) if __name__ == "__main__": - if CANONICAL_MODELS_ONLY: - models_attentions_weights = list(product(CANONICAL_MODELS_LIST, ATTENTION_COFIGS, WEIGHTS_CONFIGS.keys())) - print(f"Total number of canonical models experiments: {len(models_attentions_weights)}") - else: - models_attentions_weights = list(product(PRETRAINED_MODELS_LIST, ATTENTION_COFIGS, WEIGHTS_CONFIGS.keys())) - print(f"Total number of pretrained models experiments: {len(models_attentions_weights)}") + setup_logging(level="INFO", format_prefix="MAIN-PROCESS") + + models_attentions_weights = list( + product(CANONICAL_PRETRAINED_OPEN_LLM_LIST, ATTENTION_COFIGS, WEIGHTS_CONFIGS.keys()) + ) + + LOGGER.info( + f"Running a total of 
{len(models_attentions_weights)} benchmarks, " + f"with {len(CANONICAL_PRETRAINED_OPEN_LLM_LIST)} models, " + f"{len(ATTENTION_COFIGS)} attentions implementations" + f"and {len(WEIGHTS_CONFIGS)} weights configurations" + ) for model, attn_implementation, weights_config in models_attentions_weights: benchmark_cuda_pytorch(model, attn_implementation, weights_config) diff --git a/llm_perf/update_open_llm_leaderboard.py b/llm_perf/update_open_llm_leaderboard.py new file mode 100644 index 00000000..0ea0827e --- /dev/null +++ b/llm_perf/update_open_llm_leaderboard.py @@ -0,0 +1,42 @@ +import subprocess + +import pandas as pd +from huggingface_hub import create_repo, upload_file + +scrapping_script = """ +git clone https://github.com/Weyaxi/scrape-open-llm-leaderboard.git +pip install -r scrape-open-llm-leaderboard/requirements.txt +python scrape-open-llm-leaderboard/main.py +rm -rf scrape-open-llm-leaderboard +""" + + +def run_scrapper(): + subprocess.run(scrapping_script, shell=True) + + +def main(): + run_scrapper() + + open_llm_leaderboard = pd.read_csv("open-llm-leaderboard.csv") + + if len(open_llm_leaderboard) > 0: + create_repo( + repo_id="optimum-benchmark/open-llm-leaderboard", + repo_type="dataset", + exist_ok=True, + private=False, + ) + upload_file( + repo_id="optimum-benchmark/open-llm-leaderboard", + commit_message="Update open LLM leaderboard", + path_or_fileobj="open-llm-leaderboard.csv", + path_in_repo="open-llm-leaderboard.csv", + repo_type="dataset", + ) + else: + raise ValueError("No models found") + + +if __name__ == "__main__": + main() diff --git a/llm_perf/utils.py b/llm_perf/utils.py index 2e1dfbbc..200908b5 100644 --- a/llm_perf/utils.py +++ b/llm_perf/utils.py @@ -1,41 +1,170 @@ -from optimum_benchmark.benchmarks.report import BenchmarkReport +from typing import Tuple +import pandas as pd -def common_errors_reporter(error, logger, subfolder, push_repo_id): - benchmark_report = BenchmarkReport.from_targets(["decode", "prefill", "per_token", "error"]) +from optimum_benchmark.report import BenchmarkReport + +OPEN_LLM_LEADERBOARD = pd.read_csv("hf://datasets/optimum-benchmark/open-llm-leaderboard/open-llm-leaderboard.csv") + + +INPUT_SHAPES = {"batch_size": 1, "sequence_length": 256} +GENERATE_KWARGS = {"max_new_tokens": 64, "min_new_tokens": 64} + + +CANONICAL_ORGANIZATIONS = [ + # big companies + *["google", "facebook", "meta", "meta-llama", "microsoft", "Intel", "TencentARC", "Salesforce"], + # collectives + *["EleutherAI", "tiiuae", "NousResearch", "Open-Orca"], + # HF related + ["bigcode", "HuggingFaceH4", "huggyllama"], + # community members + ["teknium"], + # startups + *[ + "mistral-community", + "openai-community", + "togethercomputer", + "stabilityai", + "CohereForAI", + "databricks", + "mistralai", + "internlm", + "Upstage", + "xai-org", + "Phind", + "01-ai", + "Deci", + "Qwen", + ], +] + + +OPEN_LLM_LIST = OPEN_LLM_LEADERBOARD.drop_duplicates(subset=["Model"])["Model"].tolist() +PRETRAINED_OPEN_LLM_LIST = ( + OPEN_LLM_LEADERBOARD[OPEN_LLM_LEADERBOARD["Type"] == "pretrained"] + .drop_duplicates(subset=["Model"])["Model"] + .tolist() +) +CANONICAL_PRETRAINED_OPEN_LLM_LIST = sorted( + [model for model in PRETRAINED_OPEN_LLM_LIST if model.split("/")[0] in CANONICAL_ORGANIZATIONS] +) + +CANONICAL_PRETRAINED_OPEN_LLM_LIST = [ + "01-ai/Yi-34B", + "01-ai/Yi-6B", + "Deci/DeciCoder-1b", + "Deci/DeciLM-7B", + "EleutherAI/gpt-j-6b", + "EleutherAI/gpt-neo-1.3B", + "EleutherAI/gpt-neo-125m", + "EleutherAI/gpt-neo-2.7B", + "EleutherAI/gpt-neox-20b", + 
"EleutherAI/polyglot-ko-12.8b", + "EleutherAI/pythia-1.3b", + "EleutherAI/pythia-1.4b", + # "EleutherAI/pythia-1.4b-deduped", + "EleutherAI/pythia-12b", + # "EleutherAI/pythia-12b-deduped", + "EleutherAI/pythia-160m", + # "EleutherAI/pythia-160m-deduped", + # "EleutherAI/pythia-1b-deduped", + "EleutherAI/pythia-2.7b", + # "EleutherAI/pythia-2.8b-deduped", + "EleutherAI/pythia-410m", + # "EleutherAI/pythia-410m-deduped", + "EleutherAI/pythia-6.7b", + # "EleutherAI/pythia-6.9b-deduped", + "EleutherAI/pythia-70m", + # "EleutherAI/pythia-70m-deduped", + "Qwen/Qwen-14B", + "Qwen/Qwen-72B", + "Qwen/Qwen-7B", + "Qwen/Qwen1.5-0.5B", + "Qwen/Qwen1.5-1.8B", + "Qwen/Qwen1.5-110B", + "Qwen/Qwen1.5-14B", + "Qwen/Qwen1.5-32B", + "Qwen/Qwen1.5-4B", + "Qwen/Qwen1.5-72B", + "Qwen/Qwen1.5-7B", + # "Qwen/Qwen1.5-7B-Chat", + "Qwen/Qwen1.5-MoE-A2.7B", + "Qwen/Qwen2-beta-14B", + "Qwen/Qwen2-beta-72B", + "Salesforce/codegen-16B-nl", + # "Salesforce/codegen-6B-multi", + "Salesforce/codegen-6B-nl", + "TencentARC/Mistral_Pro_8B_v0.1", + "databricks/dbrx-base", + "facebook/opt-125m", + "facebook/opt-13b", + "facebook/opt-2.7b", + "facebook/opt-30b", + "facebook/opt-350m", + "facebook/opt-6.7b", + "facebook/opt-66b", + "facebook/xglm-4.5B", + "facebook/xglm-564M", + "facebook/xglm-7.5B", + "google/gemma-7b", + "google/recurrentgemma-2b", + "internlm/internlm-20b", + "internlm/internlm2-20b", + "meta-llama/Llama-2-13b-hf", + "meta-llama/Llama-2-7b-hf", + "meta-llama/Meta-Llama-3-8B", + "meta-llama/Meta-Llama-3-70B", + "microsoft/phi-1_5", + "microsoft/rho-math-1b-v0.1", + "mistralai/Mistral-7B-v0.1", + "mistralai/Mixtral-8x22B-v0.1", + "mistralai/Mixtral-8x7B-v0.1", + "openai-community/gpt2", + "openai-community/gpt2-large", + "stabilityai/stablelm-2-12b", + "stabilityai/stablelm-2-1_6b", + "stabilityai/stablelm-3b-4e1t", + "stabilityai/stablelm-base-alpha-3b", + "stabilityai/stablelm-base-alpha-7b", + # "stabilityai/stablelm-base-alpha-7b-v2", + "tiiuae/falcon-180B", + "tiiuae/falcon-40b", + "tiiuae/falcon-7b", + "tiiuae/falcon-rw-1b", + # "togethercomputer/RedPajama-INCITE-7B-Base", + "togethercomputer/RedPajama-INCITE-Base-3B-v1", + "togethercomputer/RedPajama-INCITE-Base-7B-v0.1", +] + + +def errors_reporter(error) -> Tuple[bool, BenchmarkReport]: + valid_error = True + benchmark_report = BenchmarkReport.from_list(["error"]) if "torch.cuda.OutOfMemoryError" in str(error): - logger.error("CUDA: Out of memory") benchmark_report.error = "CUDA: Out of memory" - benchmark_report.push_to_hub(subfolder=subfolder, repo_id=push_repo_id, private=True) elif "gptq" in str(error) and "assert outfeatures % 32 == 0" in str(error): - logger.error("GPTQ: assert outfeatures % 32 == 0") benchmark_report.error = "GPTQ: assert outfeatures % 32 == 0" - benchmark_report.push_to_hub(subfolder=subfolder, repo_id=push_repo_id, private=True) elif "gptq" in str(error) and "assert infeatures % self.group_size == 0" in str(error): - logger.error("GPTQ: assert infeatures % self.group_size == 0") benchmark_report.error = "GPTQ: assert infeatures % self.group_size == 0" - benchmark_report.push_to_hub(subfolder=subfolder, repo_id=push_repo_id, private=True) elif "support Flash Attention 2.0" in str(error): - logger.error("Flash Attention 2.0: not supported yet") benchmark_report.error = "Flash Attention 2.0: not supported yet" - benchmark_report.push_to_hub(subfolder=subfolder, repo_id=push_repo_id, private=True) elif "support an attention implementation through torch.nn.functional.scaled_dot_product_attention" in str(error): - logger.error("SDPA: 
not supported yet") benchmark_report.error = "SDPA: not supported yet" - benchmark_report.push_to_hub(subfolder=subfolder, repo_id=push_repo_id, private=True) elif "FlashAttention only support fp16 and bf16 data type" in str(error): - logger.error("FlashAttention: only support fp16 and bf16 data type") benchmark_report.error = "FlashAttention: only support fp16 and bf16 data type" - benchmark_report.push_to_hub(subfolder=subfolder, repo_id=push_repo_id, private=True) else: - logger.error(f"Unknown error: {error}") + valid_error = False + benchmark_report.error = f"Unknown error: {error}" + + return valid_error, benchmark_report -def is_experiment_conducted(experiment_config, push_repo_id, subfolder): +def is_benchmark_conducted(benchmark_config, push_repo_id, subfolder): try: - loaded_experiment_config = experiment_config.from_pretrained(repo_id=push_repo_id, subfolder=subfolder) - if loaded_experiment_config.to_dict() == experiment_config.to_dict(): + loaded_benchmark_config = benchmark_config.from_pretrained(repo_id=push_repo_id, subfolder=subfolder) + if loaded_benchmark_config.to_dict() == benchmark_config.to_dict(): BenchmarkReport.from_pretrained(repo_id=push_repo_id, subfolder=subfolder) return True except Exception: @@ -44,8 +173,8 @@ def is_experiment_conducted(experiment_config, push_repo_id, subfolder): return False -def is_experiment_not_supported(torch_dtype, attn_implementation): - if attn_implementation == "flash_attention_2" and torch_dtype == "float32": - return True +def is_benchmark_supported(weights_config, attn_implementation): + if attn_implementation == "flash_attention_2" and weights_config == "float32": + return False - return False + return True