From 0744c4bc9b32e320ed1543673d69cba5783e18c7 Mon Sep 17 00:00:00 2001
From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com>
Date: Sun, 16 Feb 2025 14:59:45 -0800
Subject: [PATCH] TopLoc Proof Generation (#64)

* pin yesterdays sglang

* add toploc subpackage

* add tokenize option to dataloader

* add base64 proofs

* sha util

* add file info metric

* add print

* null proof fallback

* unit tests

* fix: null case dont hardcode

* skip prefill to avoid the full cache case

* fix: string demarcation

* tokenize -> do_tokenization

* move prime metric instantiation out of dataloader

* kill a few people from the terrible startup time

* Update src/genesys/toploc/__init__.py

Co-authored-by: samsja <55492238+samsja@users.noreply.github.com>

* use abs paths for imports

* remove do_tokenization

* remove toploc code

* add toploc dependency

* change import path

* fix: do explicit tokenize and remove template bug

* Update src/genesys/utils.py

Co-authored-by: samsja <55492238+samsja@users.noreply.github.com>

---------

Co-authored-by: samsja <55492238+samsja@users.noreply.github.com>
---
 pyproject.toml          |  3 ++-
 src/genesys/data.py     | 19 +++++++++++--------
 src/genesys/generate.py | 32 ++++++++++++++++++++++----------
 src/genesys/utils.py    |  3 ++-
 uv.lock                 | 17 +++++++++++++++++
 5 files changed, 54 insertions(+), 20 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index d477bd6..3ab2265 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,8 @@ dependencies = [
     "google-cloud-storage",
     "tomli",
     "docker>=7.1.0",
-    "pynvml>=12.0.0"
+    "pynvml>=12.0.0",
+    "toploc>=0.0.2",
 ]
 
 [project.optional-dependencies]
diff --git a/src/genesys/data.py b/src/genesys/data.py
index da90d04..52071dc 100644
--- a/src/genesys/data.py
+++ b/src/genesys/data.py
@@ -43,7 +43,7 @@ class DataLoaderGenesys:
     Each dataset that is pass must have a "train" split and the content must be a list of dict with at least a "problem" and a "ground_truth" key.
""" - def __init__(self, config: DataConfig, tokenizer: AutoTokenizer): + def __init__(self, config: DataConfig, tokenizer: AutoTokenizer, prime_metric: PrimeMetric): self.config = config self.paths = list(config.path.split(",")) @@ -92,7 +92,7 @@ def _add_column(dataset, path): for i, length in enumerate(self.dataset_lengths) ] - self.prime_metric = PrimeMetric(disable=not (config.prime_log), period=config.prime_log_freq) + self.prime_metric = prime_metric def _prepare_batch(self, batch: dict, dataset: str) -> tuple: batch = repeat_elements( @@ -103,14 +103,15 @@ def _prepare_batch(self, batch: dict, dataset: str) -> tuple: [ {"role": "user", "content": b["prompt"]}, {"role": "assistant", "content": "\n" + b["llm_response_first_time"]}, - {"role": "assistant", "content": ""}, # this message needs to be here so hf templating works, we're stripping it out again below + { + "role": "assistant", + "content": "", + }, # this message needs to be here so hf templating works, we're stripping it out again below ] for b in batch ] batch_inputs = self.tokenizer.apply_chat_template( - batch_messages, - tokenize=False, - continue_final_message=True + batch_messages, tokenize=False, continue_final_message=True ) unwanted_suffix = "<|end▁of▁sentence|><|Assistant|><|end▁of▁sentence|>" # strip out last message for i, inp in enumerate(batch_inputs): @@ -120,9 +121,11 @@ def _prepare_batch(self, batch: dict, dataset: str) -> tuple: batch_messages = [ [{"role": "user", "content": b["prompt"]}, {"role": "assistant", "content": "/n"}] for b in batch ] - - batch_inputs = self.tokenizer.apply_chat_template(batch_messages, tokenize=False, continue_final_message=True) + batch_inputs = self.tokenizer.apply_chat_template( + batch_messages, tokenize=False, continue_final_message=True + ) + batch_inputs = self.tokenizer(batch_inputs, add_special_tokens=False).input_ids return batch_inputs, batch def __iter__(self) -> Generator[tuple, None, None]: diff --git a/src/genesys/generate.py b/src/genesys/generate.py index 164e777..aadbd93 100644 --- a/src/genesys/generate.py +++ b/src/genesys/generate.py @@ -14,7 +14,9 @@ log, console, ) +from genesys.prime_metrics import PrimeMetric from genesys.data import DataConfig, DataLoaderGenesys +from toploc import build_proofs_base64, sha256sum class GenerateConfig(BaseConfig): @@ -46,6 +48,7 @@ def check_batch_size(self): def main(config: GenerateConfig): # Initial welcome table display_config_panel(console, config) + prime_metric = PrimeMetric(disable=not (config.data.prime_log), period=config.data.prime_log_freq) log("[bold yellow] Loading model and initializing pipeline...[/]") @@ -56,11 +59,7 @@ def main(config: GenerateConfig): assert gcp_credentials is not None, "the GCP_CREDENTIALS_BASE64 environment variable is not set" if not os.path.exists(config.path_output): os.makedirs(config.path_output) - gcp_bucket = ( - GcpBucket(config.gcp_bucket, gcp_credentials) - if config.gcp_bucket is not None - else None - ) + gcp_bucket = GcpBucket(config.gcp_bucket, gcp_credentials) if config.gcp_bucket is not None else None if config.pre_download_retry > 0: log("[cyan] Pre-downloading model...[/]") @@ -68,13 +67,18 @@ def main(config: GenerateConfig): log("[cyan] Loading model and Engine...[/]") - llm = sgl.Engine(model_path=config.name_model, tp_size=config.num_gpus) + llm = sgl.Engine( + model_path=config.name_model, tp_size=config.num_gpus, return_hidden_states=True, skip_tokenizer_init=True + ) log("[cyan] Loading tokenizer...[/]") tokenizer = 
+    tokenizer.chat_template = tokenizer.chat_template.replace(
+        "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}", ""
+    )
 
     log("[cyan] Loading dataloader...[/]")
-    dataloader = DataLoaderGenesys(config.data, tokenizer=tokenizer)
+    dataloader = DataLoaderGenesys(config.data, tokenizer=tokenizer, prime_metric=prime_metric)
 
     machine_info = get_machine_info()
     log("[bold green]✨ Setup complete! Starting generation...[/]")
@@ -85,13 +89,18 @@ def main(config: GenerateConfig):
     total_samples = 0
 
     for batch_inputs, batch in dataloader:
-        responses = llm.generate(batch_inputs, sampling_params)
-        for batch_element, response in zip(batch, responses):
-            batch_element["llm_response"] = response["text"]
+        responses = llm.generate(input_ids=batch_inputs, sampling_params=sampling_params)
+        for batch_input, batch_element, response in zip(batch_inputs, batch, responses):
+            batch_element["llm_response"] = tokenizer.decode(response["token_ids"], skip_special_tokens=True)
             batch_element["response_id"] = f"{batch_element['problem_id']}_{generate_short_id()}"
             batch_element["model_name"] = config.name_model
             batch_element["generation_config"] = sampling_params
             batch_element["machine_info"] = machine_info
+            batch_element["input_ids"] = batch_input
+            batch_element["output_ids"] = response["token_ids"]
+            batch_element["proof"] = "".join(
+                build_proofs_base64(response["meta_info"]["hidden_states"], 32, 128, skip_prefill=True)
+            )
             all_results.append(batch_element)
 
         total_samples += len(batch)
@@ -99,6 +108,9 @@ def main(config: GenerateConfig):
             file_name = f"{config.out_file_prefix}_{uuid.uuid4()}.jsonl"
             file = os.path.join(config.path_output, file_name)
             save_batch_results(all_results, file, gcp_bucket)
+            file_sha = sha256sum(file)
+            prime_metric.log_prime({"file_sha": file_sha, "file_name": file_name})
+            log(f"[bold green]✨ Saved {len(all_results)} samples to {file} with sha {file_sha or 'NA'}[/]")
             all_results = []
 
     log(f"[bold green]✨ Generation complete! Total samples: {total_samples}[/]")
diff --git a/src/genesys/utils.py b/src/genesys/utils.py
index 88f8c94..57460a0 100644
--- a/src/genesys/utils.py
+++ b/src/genesys/utils.py
@@ -20,6 +20,7 @@
 from rich.console import Console
 from huggingface_hub import snapshot_download
 from datasets import load_dataset
+from pathlib import Path
 
 
 class GcpBucket:
@@ -75,7 +76,7 @@ def __del__(self):
         self.worker_thread.join()
 
 
-def save_batch_results(batch_results, results_file, gcp_bucket: GcpBucket | None = None):
+def save_batch_results(batch_results, results_file: str | Path, gcp_bucket: GcpBucket | None = None):
     # Save locally first
     with open(results_file, "a") as f:
         for result in batch_results:
diff --git a/uv.lock b/uv.lock
index 3d46d8c..8f2ff1f 100644
--- a/uv.lock
+++ b/uv.lock
@@ -648,6 +648,7 @@ dependencies = [
     { name = "pynvml" },
     { name = "rich" },
     { name = "tomli" },
+    { name = "toploc" },
     { name = "torch" },
     { name = "transformers" },
 ]
@@ -690,6 +691,7 @@ requires-dist = [
     { name = "sglang", extras = ["srt"], marker = "extra == 'sglang'", specifier = ">=0.4.3" },
     { name = "sympy", marker = "extra == 'sglang'" },
     { name = "tomli" },
+    { name = "toploc", specifier = ">=0.0.2" },
     { name = "torch", specifier = ">=2.4.1" },
     { name = "transformers", specifier = ">=4.44.2" },
     { name = "uvicorn", marker = "extra == 'sglang'" },
@@ -3366,6 +3368,21 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 },
 ]
 
+[[package]]
+name = "toploc"
+version = "0.0.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "numpy" },
+    { name = "torch" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/60/0d/a230f5252a2f1c3872311ad24b6c88c82c24e5da5e01ffb4473e8000fc4a/toploc-0.0.2.tar.gz", hash = "sha256:c6f2270c90af73697a5c7d2cf575fd84b2164fdc19846931e4be6e09749a22c2", size = 13103 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/1a/8a/9c56d952e50f25977960e373db4535c59caa676cbb448a51f00d55aa61ea/toploc-0.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59e5e0438589d72ead9c72e44ff43b65ae200d4f3f094fa5e69ee0afacc42741", size = 5395245 },
+    { url = "https://files.pythonhosted.org/packages/7e/91/280e9a21e0881432b204a4f0d3b4742a6c1e7e816177661d27f464e05369/toploc-0.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:551319b4170d1f71a3d48cadd11bfba44572df2bcf9ced7a4da18844a8797e97", size = 5421389 },
+    { url = "https://files.pythonhosted.org/packages/96/75/4612d9f1c37af9474a90b3d2205cf9c0000670c79cea32ae18b126e0e8a3/toploc-0.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8927e1db712a461b1226b9a47585182df0f6198fdcab89b68bfd076479281401", size = 5424790 },
+]
+
 [[package]]
 name = "torch"
 version = "2.5.1"
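
Note for reviewers (not part of the patch): a minimal Python sketch of how the two toploc functions imported in generate.py fit together. The activation tensors below are faked with random data (the real ones come back from the sglang engine via return_hidden_states=True), the hidden size of 4096 and the per-step tensor shape are assumptions that depend on the model, and results.jsonl is a placeholder path; the 32/128 arguments mirror the decode-batching size and top-k used in the call above.

    import torch
    from toploc import build_proofs_base64, sha256sum

    # Stand-in for response["meta_info"]["hidden_states"]: one activation
    # tensor per generation step (shape and hidden size are assumptions;
    # sglang supplies the real tensors when return_hidden_states=True).
    hidden_states = [torch.randn(1, 4096) for _ in range(64)]

    # Same parameters as the generate.py call: build base64 commitments over
    # chunks of 32 decode steps, keeping the top-128 activation features per
    # chunk, and skip the prefill activations (skip_prefill=True).
    proof = "".join(build_proofs_base64(hidden_states, 32, 128, skip_prefill=True))

    # Fingerprint a results file the way the patch does before logging the
    # name/hash pair to PrimeMetric (a dummy file is written so this runs).
    with open("results.jsonl", "w") as f:
        f.write("{}\n")
    file_sha = sha256sum("results.jsonl")
    print(proof[:32], file_sha)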