TopLoc Proof Generation (#64)
* pin yesterday's sglang

* add toploc subpackage

* add tokenize option to dataloader

* add base64 proofs

* sha util

* add file info metric

* add print

* null proof fallback

* unit tests

* fix: null case don't hardcode

* skip prefill to avoid the full cache case

* fix: string demarcation

* tokenize -> do_tokenization

* move prime metric instantiation out of dataloader

* kill a few people from the terrible startup time

* Update src/genesys/toploc/__init__.py

Co-authored-by: samsja <[email protected]>

* use abs paths for imports

* remove do_tokenization

* remove toploc code

* add toploc dependency

* change import path

* fix: do explicit tokenize and remove template bug

* Update src/genesys/utils.py

Co-authored-by: samsja <[email protected]>

---------

Co-authored-by: samsja <[email protected]>
Jackmin801 and samsja authored Feb 16, 2025
1 parent e4f7475 commit 0744c4b
Showing 5 changed files with 54 additions and 20 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml

```diff
@@ -14,7 +14,8 @@ dependencies = [
     "google-cloud-storage",
     "tomli",
     "docker>=7.1.0",
-    "pynvml>=12.0.0"
+    "pynvml>=12.0.0",
+    "toploc>=0.0.2",
 ]
 
 [project.optional-dependencies]
```
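The new `toploc` pin supplies the proof helpers that `generate.py` consumes below. A minimal import smoke test, assuming the package layout this commit imports from (`toploc>=0.0.2`):

```python
# Smoke test for the new dependency: both helpers used by generate.py
# should be importable from the package root (assumed toploc>=0.0.2 layout).
from toploc import build_proofs_base64, sha256sum

assert callable(build_proofs_base64) and callable(sha256sum)
```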
19 changes: 11 additions & 8 deletions src/genesys/data.py

```diff
@@ -43,7 +43,7 @@ class DataLoaderGenesys:
     Each dataset that is pass must have a "train" split and the content must be a list of dict with at least a "problem" and a "ground_truth" key.
     """
 
-    def __init__(self, config: DataConfig, tokenizer: AutoTokenizer):
+    def __init__(self, config: DataConfig, tokenizer: AutoTokenizer, prime_metric: PrimeMetric):
         self.config = config
 
         self.paths = list(config.path.split(","))
@@ -92,7 +92,7 @@ def _add_column(dataset, path):
             for i, length in enumerate(self.dataset_lengths)
         ]
 
-        self.prime_metric = PrimeMetric(disable=not (config.prime_log), period=config.prime_log_freq)
+        self.prime_metric = prime_metric
 
     def _prepare_batch(self, batch: dict, dataset: str) -> tuple:
         batch = repeat_elements(
@@ -103,14 +103,15 @@ def _prepare_batch(self, batch: dict, dataset: str) -> tuple:
             [
                 {"role": "user", "content": b["prompt"]},
                 {"role": "assistant", "content": "<think>\n" + b["llm_response_first_time"]},
-                {"role": "assistant", "content": ""},  # this message needs to be here so hf templating works, we're stripping it out again below
+                {
+                    "role": "assistant",
+                    "content": "",
+                },  # this message needs to be here so hf templating works, we're stripping it out again below
             ]
             for b in batch
         ]
         batch_inputs = self.tokenizer.apply_chat_template(
-            batch_messages,
-            tokenize=False,
-            continue_final_message=True
+            batch_messages, tokenize=False, continue_final_message=True
         )
         unwanted_suffix = "<|end▁of▁sentence|><|Assistant|><|end▁of▁sentence|>"  # strip out last message
         for i, inp in enumerate(batch_inputs):
@@ -120,9 +121,11 @@ def _prepare_batch(self, batch: dict, dataset: str) -> tuple:
         batch_messages = [
             [{"role": "user", "content": b["prompt"]}, {"role": "assistant", "content": "<think>/n"}] for b in batch
         ]
-        batch_inputs = self.tokenizer.apply_chat_template(batch_messages, tokenize=False, continue_final_message=True)
+
+        batch_inputs = self.tokenizer.apply_chat_template(
+            batch_messages, tokenize=False, continue_final_message=True
+        )
 
         batch_inputs = self.tokenizer(batch_inputs, add_special_tokens=False).input_ids
         return batch_inputs, batch
 
     def __iter__(self) -> Generator[tuple, None, None]:
```
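The empty trailing assistant message above is a workaround for HF chat templating: `continue_final_message=True` needs a final assistant turn to continue, and the extra suffix it renders is stripped back out. A self-contained sketch of the same pattern, assuming a DeepSeek-R1-style tokenizer (the model name is illustrative; the suffix string is copied from the hunk above):

```python
# Sketch of the _prepare_batch templating workaround (assumptions: a
# DeepSeek-R1-style chat template; unwanted_suffix copied from the diff).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")

messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "<think>\n" + "partial reasoning to continue"},
    # Empty assistant turn so hf templating accepts continue_final_message;
    # the suffix it renders is stripped right back out below.
    {"role": "assistant", "content": ""},
]

text = tokenizer.apply_chat_template(messages, tokenize=False, continue_final_message=True)

unwanted_suffix = "<|end▁of▁sentence|><|Assistant|><|end▁of▁sentence|>"
if text.endswith(unwanted_suffix):
    text = text[: -len(unwanted_suffix)]

# Tokenize explicitly, mirroring the dataloader's final step.
input_ids = tokenizer(text, add_special_tokens=False).input_ids
```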
32 changes: 22 additions & 10 deletions src/genesys/generate.py

```diff
@@ -14,7 +14,9 @@
     log,
     console,
 )
+from genesys.prime_metrics import PrimeMetric
 from genesys.data import DataConfig, DataLoaderGenesys
+from toploc import build_proofs_base64, sha256sum
 
 
 class GenerateConfig(BaseConfig):
@@ -46,6 +48,7 @@ def check_batch_size(self):
 def main(config: GenerateConfig):
     # Initial welcome table
     display_config_panel(console, config)
+    prime_metric = PrimeMetric(disable=not (config.data.prime_log), period=config.data.prime_log_freq)
 
     log("[bold yellow] Loading model and initializing pipeline...[/]")
 
@@ -56,25 +59,26 @@ def main(config: GenerateConfig):
         assert gcp_credentials is not None, "the GCP_CREDENTIALS_BASE64 environment variable is not set"
         if not os.path.exists(config.path_output):
             os.makedirs(config.path_output)
-    gcp_bucket = (
-        GcpBucket(config.gcp_bucket, gcp_credentials)
-        if config.gcp_bucket is not None
-        else None
-    )
+    gcp_bucket = GcpBucket(config.gcp_bucket, gcp_credentials) if config.gcp_bucket is not None else None
 
     if config.pre_download_retry > 0:
         log("[cyan] Pre-downloading model...[/]")
         download_model(config.name_model, config.pre_download_retry)
 
     log("[cyan] Loading model and Engine...[/]")
 
-    llm = sgl.Engine(model_path=config.name_model, tp_size=config.num_gpus)
+    llm = sgl.Engine(
+        model_path=config.name_model, tp_size=config.num_gpus, return_hidden_states=True, skip_tokenizer_init=True
+    )
 
     log("[cyan] Loading tokenizer...[/]")
     tokenizer = AutoTokenizer.from_pretrained(config.name_model)
+    tokenizer.chat_template = tokenizer.chat_template.replace(
+        "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}", ""
+    )
 
     log("[cyan] Loading dataloader...[/]")
-    dataloader = DataLoaderGenesys(config.data, tokenizer=tokenizer)
+    dataloader = DataLoaderGenesys(config.data, tokenizer=tokenizer, prime_metric=prime_metric)
     machine_info = get_machine_info()
 
     log("[bold green]✨ Setup complete! Starting generation...[/]")
```
```diff
@@ -85,20 +89,28 @@ def main(config: GenerateConfig):
     total_samples = 0
 
     for batch_inputs, batch in dataloader:
-        responses = llm.generate(batch_inputs, sampling_params)
-        for batch_element, response in zip(batch, responses):
-            batch_element["llm_response"] = response["text"]
+        responses = llm.generate(input_ids=batch_inputs, sampling_params=sampling_params)
+        for batch_input, batch_element, response in zip(batch_inputs, batch, responses):
+            batch_element["llm_response"] = tokenizer.decode(response["token_ids"], skip_special_tokens=True)
             batch_element["response_id"] = f"{batch_element['problem_id']}_{generate_short_id()}"
             batch_element["model_name"] = config.name_model
             batch_element["generation_config"] = sampling_params
             batch_element["machine_info"] = machine_info
+            batch_element["input_ids"] = batch_input
+            batch_element["output_ids"] = response["token_ids"]
+            batch_element["proof"] = "".join(
+                build_proofs_base64(response["meta_info"]["hidden_states"], 32, 128, skip_prefill=True)
+            )
             all_results.append(batch_element)
         total_samples += len(batch)
 
         if len(all_results) >= config.sample_per_file:
             file_name = f"{config.out_file_prefix}_{uuid.uuid4()}.jsonl"
             file = os.path.join(config.path_output, file_name)
             save_batch_results(all_results, file, gcp_bucket)
+            file_sha = sha256sum(file)
+            prime_metric.log_prime({"file_sha": file_sha, "file_name": file_name})
+            log(f"[bold green]✨ Saved {len(all_results)} samples to {file} with sha {file_sha or 'NA'}[/]")
             all_results = []
 
     log(f"[bold green]✨ Generation complete! Total samples: {total_samples}[/]")
```
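The proof line packs a lot in: TopLoc commits to the top-k decode activations in fixed-size chunks, and the chunks are concatenated into one base64 string per sample. A sketch of that step in isolation, assuming toploc's `build_proofs_base64(activations, decode_batching_size, topk, skip_prefill)` signature as called above:

```python
# Sketch of the TopLoc proof step (assumptions: the positional signature
# used in the diff; hidden_states as returned by sglang in meta_info).
from toploc import build_proofs_base64, sha256sum

def proof_for(response: dict) -> str:
    hidden_states = response["meta_info"]["hidden_states"]
    # One proof chunk per 32 decode steps over the top-128 activation
    # features; skip_prefill sidesteps the full-prompt-cache case noted in
    # the commit log. Chunks are joined into a single base64 string.
    return "".join(build_proofs_base64(hidden_states, 32, 128, skip_prefill=True))

# After each shard is saved, its sha256 is logged so a verifier can tie
# proofs to the exact output file:
#   file_sha = sha256sum(file)
#   prime_metric.log_prime({"file_sha": file_sha, "file_name": file_name})
```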
3 changes: 2 additions & 1 deletion src/genesys/utils.py

```diff
@@ -20,6 +20,7 @@
 from rich.console import Console
 from huggingface_hub import snapshot_download
 from datasets import load_dataset
+from pathlib import Path
 
 
 class GcpBucket:
@@ -75,7 +76,7 @@ def __del__(self):
         self.worker_thread.join()
 
 
-def save_batch_results(batch_results, results_file, gcp_bucket: GcpBucket | None = None):
+def save_batch_results(batch_results, results_file: str | Path, gcp_bucket: GcpBucket | None = None):
     # Save locally first
     with open(results_file, "a") as f:
         for result in batch_results:
```
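A small usage sketch of the widened signature (paths and the sample record are hypothetical; `save_batch_results` is the function from the diff above):

```python
# Both call styles now match the annotation: plain str and pathlib.Path.
from pathlib import Path
from genesys.utils import save_batch_results

Path("out").mkdir(exist_ok=True)  # append-mode open needs the directory
results = [{"problem_id": "p0", "llm_response": "4"}]
save_batch_results(results, "out/shard_000.jsonl")
save_batch_results(results, Path("out") / "shard_001.jsonl")
```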
17 changes: 17 additions & 0 deletions uv.lock

(generated lockfile; diff not rendered)
