add openwebtext
pomonam committed Jul 10, 2024
1 parent 7209dac commit 7374ddd
Showing 4 changed files with 30 additions and 23 deletions.
examples/openwebtext/README.md (23 changes: 15 additions & 8 deletions)
@@ -1,16 +1,23 @@
-```bash
-python analyze.py --factor_batch_size 32 \
-    --train_batch_size 64 \
-    --factor_strategy ekfac
-```
# OpenWebText & Llama-3-8B Example

This directory contains scripts for computing influence scores on a subset of the OpenWebText dataset.
The pipeline is inspired by [the LogIX language modeling example](https://github.com/logix-project/logix/tree/main/examples/language_modeling).
+Install the necessary packages:

```bash
-torchrun --standalone --nnodes=1 --nproc-per-node=2 analyze.py --factor_batch_size 8 \
-    --factor_strategy ekfac
+pip install -r requirements.txt
```

+## Training
+
+We will use the pre-trained model [from HuggingFace](https://huggingface.co/meta-llama/Meta-Llama-3-8B).
+
+## Computing EKFAC Factors
+
+```bash
+torchrun --standalone --nnodes=1 --nproc-per-node=4 fit_factors.py --factor_batch_size 4
+```
-```


The `generate.py` script contains code to generate responses from the Llama-3-8B model for a given prompt.
I saved some prompt and completion pairs in `data/data.json`.
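For illustration, here is a minimal sketch of the two steps described above: loading the pre-trained checkpoint and generating a completion. It assumes the standard `transformers` API; the prompt, decoding settings, and JSON layout are placeholders, not the actual contents of `generate.py` or `data/data.json`.

```python
# Minimal sketch (not the actual generate.py): load the pre-trained Llama-3-8B
# checkpoint referenced in the README and produce one prompt/completion pair.
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "meta-llama/Meta-Llama-3-8B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto")

prompt = "Inflation is typically measured by"  # placeholder prompt
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=False)
# Keep only the newly generated tokens as the completion.
completion = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

# Store the pair; the real schema of data/data.json may differ.
with open("data/data.json", "w") as f:
    json.dump([{"prompt": prompt, "completion": completion}], f, indent=2)
```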
examples/openwebtext/fit_factors.py (17 changes: 7 additions & 10 deletions)
@@ -8,10 +8,7 @@
from torch import nn
from transformers import default_data_collator

-from examples.openwebtext.pipeline import (
-construct_llama3,
-get_openwebtext_dataset,
-)
+from examples.openwebtext.pipeline import construct_llama3, get_openwebtext_dataset
from examples.openwebtext.task import LanguageModelingTask
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.utils.common.factor_arguments import (
@@ -75,19 +72,19 @@ def main():
profile=args.profile,
)
# Configure parameters for DataLoader.
-dataloader_kwargs = DataLoaderKwargs(collate_fn=default_data_collator)
+dataloader_kwargs = DataLoaderKwargs(num_workers=4, collate_fn=default_data_collator, pin_memory=True)
analyzer.set_dataloader_kwargs(dataloader_kwargs)

factors_name = args.factor_strategy
-factor_args = extreme_reduce_memory_factor_arguments(strategy=args.factor_strategy,
-module_partitions=2,
-dtype=torch.bfloat16)
-factor_args.covariance_max_examples=4
+factor_args = extreme_reduce_memory_factor_arguments(
+strategy=args.factor_strategy, module_partitions=1, dtype=torch.bfloat16
+)
+factor_args.covariance_module_partitions = 2
+factor_args.lambda_module_partitions = 4
analyzer.fit_all_factors(
factors_name=factors_name,
dataset=train_dataset,
per_device_batch_size=args.factor_batch_size,
-# per_device_batch_size=None,
factor_args=factor_args,
overwrite_output_dir=False,
)
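For orientation, the calls visible in this file compose roughly as follows. This is a condensed sketch, not the full script: the `Analyzer` constructor arguments, the `DataLoaderKwargs` import path, and the argument-free `get_openwebtext_dataset()`/`LanguageModelingTask()` calls are assumptions.

```python
# Condensed sketch of the factor-fitting flow in fit_factors.py (details assumed).
import torch
from transformers import default_data_collator

from examples.openwebtext.pipeline import construct_llama3, get_openwebtext_dataset
from examples.openwebtext.task import LanguageModelingTask
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.utils.common.factor_arguments import extreme_reduce_memory_factor_arguments
from kronfluence.utils.dataset import DataLoaderKwargs  # assumed import path

train_dataset = get_openwebtext_dataset()
task = LanguageModelingTask()
model = prepare_model(model=construct_llama3(), task=task)

analyzer = Analyzer(analysis_name="openwebtext", model=model, task=task)  # assumed arguments
analyzer.set_dataloader_kwargs(
    DataLoaderKwargs(num_workers=4, collate_fn=default_data_collator, pin_memory=True)
)

# Memory-saving factor arguments, mirroring the diff: bfloat16 EKFAC with the
# covariance and lambda computations partitioned across module groups.
factor_args = extreme_reduce_memory_factor_arguments(
    strategy="ekfac", module_partitions=1, dtype=torch.bfloat16
)
factor_args.covariance_module_partitions = 2
factor_args.lambda_module_partitions = 4

analyzer.fit_all_factors(
    factors_name="ekfac",
    dataset=train_dataset,
    per_device_batch_size=4,  # --factor_batch_size in the README command
    factor_args=factor_args,
    overwrite_output_dir=False,
)
```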
examples/openwebtext/pipeline.py (3 changes: 1 addition & 2 deletions)
@@ -39,8 +39,7 @@ def tokenize_function(examples):
results = tokenizer(examples[text_column_name], truncation=True, padding=True, max_length=MAX_LENGTH)
results["labels"] = results["input_ids"].copy()
results["labels"] = [
-[-100 if token == tokenizer.pad_token_id else token for token in label]
-for label in results["labels"]
+[-100 if token == tokenizer.pad_token_id else token for token in label] for label in results["labels"]
]
return results

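For context, this masking pairs with `ignore_index=-100` in the loss (see `task.py` below), so padded positions contribute nothing to the summed loss. A tiny sketch with made-up token ids:

```python
# Illustration only: positions holding the pad token id become -100 in the labels,
# and F.cross_entropy(..., ignore_index=-100) later skips them.
pad_token_id = 0  # placeholder; pipeline.py takes this from the tokenizer
input_ids = [[15, 27, 3, pad_token_id, pad_token_id]]
labels = [
    [-100 if token == pad_token_id else token for token in label] for label in input_ids
]
print(labels)  # [[15, 27, 3, -100, -100]]
```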
examples/openwebtext/task.py (10 changes: 7 additions & 3 deletions)
@@ -26,9 +26,7 @@ def compute_train_loss(
if not sample:
labels = batch["labels"]
shift_labels = labels[..., 1:].contiguous()
-summed_loss = F.cross_entropy(
-logits, shift_labels.view(-1), reduction="sum", ignore_index=-100
-)
+summed_loss = F.cross_entropy(logits, shift_labels.view(-1), reduction="sum", ignore_index=-100)
else:
with torch.no_grad():
probs = torch.nn.functional.softmax(logits.detach(), dim=-1)
@@ -55,6 +53,12 @@ def compute_measurement(
def get_influence_tracked_modules(self) -> List[str]:
total_modules = []

+for i in range(32):
+total_modules.append(f"model.layers.{i}.self_attn.q_proj")
+total_modules.append(f"model.layers.{i}.self_attn.k_proj")
+total_modules.append(f"model.layers.{i}.self_attn.v_proj")
+total_modules.append(f"model.layers.{i}.self_attn.o_proj")

for i in range(32):
total_modules.append(f"model.layers.{i}.mlp.gate_proj")
total_modules.append(f"model.layers.{i}.mlp.up_proj")
