diff --git a/examples/openwebtext/README.md b/examples/openwebtext/README.md index a35f0de..e8a27c6 100644 --- a/examples/openwebtext/README.md +++ b/examples/openwebtext/README.md @@ -1,16 +1,23 @@ -```bash -python analyze.py --factor_batch_size 32 \ - --train_batch_size 64 \ - --factor_strategy ekfac -``` +# OpenWebText & Llama-3-8B Example +This repository contains scripts for computing influence scores on a subset of the OpenWebText dataset. +The pipeline is inspired by [LogIX](https://github.com/logix-project/logix/tree/main/examples/language_modeling). +Install the necessary packages: ```bash -torchrun --standalone --nnodes=1 --nproc-per-node=2 analyze.py --factor_batch_size 8 \ - --factor_strategy ekfac +pip install -r requirements.txt ``` +## Training + +We will use the pre-trained model [from HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3-8B). + +## Computing EKFAC Factors ```bash torchrun --standalone --nnodes=1 --nproc-per-node=4 fit_factors.py --factor_batch_size 4 -``` \ No newline at end of file +``` + + +The `generate.py` file contains code to generate responses from the Llama-3-8B model for a given prompt. +I saved some prompt and completion pairs to the file `data/data.json`. \ No newline at end of file diff --git a/examples/openwebtext/fit_factors.py b/examples/openwebtext/fit_factors.py index d44d3b5..a3b4146 100644 --- a/examples/openwebtext/fit_factors.py +++ b/examples/openwebtext/fit_factors.py @@ -8,10 +8,7 @@ from torch import nn from transformers import default_data_collator -from examples.openwebtext.pipeline import ( - construct_llama3, - get_openwebtext_dataset, -) +from examples.openwebtext.pipeline import construct_llama3, get_openwebtext_dataset from examples.openwebtext.task import LanguageModelingTask from kronfluence.analyzer import Analyzer, prepare_model from kronfluence.utils.common.factor_arguments import ( @@ -75,19 +72,19 @@ def main(): profile=args.profile, ) # Configure parameters for DataLoader. 
- dataloader_kwargs = DataLoaderKwargs(collate_fn=default_data_collator) + dataloader_kwargs = DataLoaderKwargs(num_workers=4, collate_fn=default_data_collator, pin_memory=True) analyzer.set_dataloader_kwargs(dataloader_kwargs) factors_name = args.factor_strategy - factor_args = extreme_reduce_memory_factor_arguments(strategy=args.factor_strategy, - module_partitions=2, - dtype=torch.bfloat16) - factor_args.covariance_max_examples=4 + factor_args = extreme_reduce_memory_factor_arguments( + strategy=args.factor_strategy, module_partitions=1, dtype=torch.bfloat16 + ) + factor_args.covariance_module_partitions = 2 + factor_args.lambda_module_partitions = 4 analyzer.fit_all_factors( factors_name=factors_name, dataset=train_dataset, per_device_batch_size=args.factor_batch_size, - # per_device_batch_size=None, factor_args=factor_args, overwrite_output_dir=False, ) diff --git a/examples/openwebtext/pipeline.py b/examples/openwebtext/pipeline.py index 4a799eb..fda0fe0 100644 --- a/examples/openwebtext/pipeline.py +++ b/examples/openwebtext/pipeline.py @@ -39,8 +39,7 @@ def tokenize_function(examples): results = tokenizer(examples[text_column_name], truncation=True, padding=True, max_length=MAX_LENGTH) results["labels"] = results["input_ids"].copy() results["labels"] = [ - [-100 if token == tokenizer.pad_token_id else token for token in label] - for label in results["labels"] + [-100 if token == tokenizer.pad_token_id else token for token in label] for label in results["labels"] ] return results diff --git a/examples/openwebtext/task.py b/examples/openwebtext/task.py index 123ef4b..3e3c264 100644 --- a/examples/openwebtext/task.py +++ b/examples/openwebtext/task.py @@ -26,9 +26,7 @@ def compute_train_loss( if not sample: labels = batch["labels"] shift_labels = labels[..., 1:].contiguous() - summed_loss = F.cross_entropy( - logits, shift_labels.view(-1), reduction="sum", ignore_index=-100 - ) + summed_loss = F.cross_entropy(logits, shift_labels.view(-1), reduction="sum", 
ignore_index=-100) else: with torch.no_grad(): probs = torch.nn.functional.softmax(logits.detach(), dim=-1) @@ -55,6 +53,12 @@ def compute_measurement( def get_influence_tracked_modules(self) -> List[str]: total_modules = [] + for i in range(32): + total_modules.append(f"model.layers.{i}.self_attn.q_proj") + total_modules.append(f"model.layers.{i}.self_attn.k_proj") + total_modules.append(f"model.layers.{i}.self_attn.v_proj") + total_modules.append(f"model.layers.{i}.self_attn.o_proj") + for i in range(32): total_modules.append(f"model.layers.{i}.mlp.gate_proj") total_modules.append(f"model.layers.{i}.mlp.up_proj")