Skip to content

Commit

Permalink
Fix task formulation
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 10, 2024
1 parent 9c3f89c commit 395bbc1
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
5 changes: 3 additions & 2 deletions examples/dailymail/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ This will fine-tune the model using the specified hyperparameters and save the f
To calculate pairwise influence scores on 10 query data points using `ekfac`, run:

```bash
python analyze.py --query_batch_size 10 \
python analyze.py --factor_batch_size 64 \
--query_batch_size 10 \
--train_batch_size 128 \
--use_half_precision \
--checkpoint_dir ./checkpoints \
Expand All @@ -43,7 +44,7 @@ Alternative options for `factor_strategy` include `identity`, `diagonal`, and `k

## Inspecting Top Influential Sequences

The `inspect.py` script prints top influential sequences for a given query.
The `inspect_examples.py` script prints top influential sequences for a given query.

```
Query Data Example:
Expand Down
16 changes: 8 additions & 8 deletions examples/dailymail/analyze.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import logging
import os
from typing import Dict, List, Optional
from typing import Dict, List

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -31,7 +31,7 @@


def parse_args():
parser = argparse.ArgumentParser(description="Influence analysis on DailyMail dataset.")
parser = argparse.ArgumentParser(description="Influence analysis on CNN DailyMail dataset.")

parser.add_argument(
"--checkpoint_dir",
Expand All @@ -43,7 +43,7 @@ def parse_args():
parser.add_argument(
"--factor_strategy",
type=str,
default="identity",
default="ekfac",
help="Strategy to compute influence factors.",
)
parser.add_argument(
Expand All @@ -67,19 +67,19 @@ def parse_args():
parser.add_argument(
"--factor_batch_size",
type=int,
default=None,
default=64,
help="Batch size for computing influence factors.",
)
parser.add_argument(
"--query_batch_size",
type=int,
default=2,
default=10,
help="Batch size for computing query gradients.",
)
parser.add_argument(
"--train_batch_size",
type=int,
default=1,
default=128,
help="Batch size for computing training gradients.",
)
parser.add_argument(
Expand Down Expand Up @@ -147,7 +147,7 @@ def compute_measurement(
masks = batch["labels"].view(-1) != -100
return -margins[masks].sum()

def tracked_modules(self) -> List[str]:
def get_influence_tracked_modules(self) -> List[str]:
total_modules = []

# Add attention layers:
Expand Down Expand Up @@ -177,7 +177,7 @@ def tracked_modules(self) -> List[str]:

return total_modules

def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]:
def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor:
return batch["attention_mask"]


Expand Down
2 changes: 1 addition & 1 deletion examples/dailymail/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def preprocess_function(examples):
batched=True,
num_proc=None,
remove_columns=column_names,
# load_from_cache_file=True,
load_from_cache_file=True,
desc="Running tokenizer on dataset.",
)
ds = train_dataset
Expand Down
2 changes: 1 addition & 1 deletion examples/dailymail/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


def parse_args():
parser = argparse.ArgumentParser(description="Train seq2seq models on DailyMail dataset.")
parser = argparse.ArgumentParser(description="Train seq2seq models on CNN DailyMail dataset.")

parser.add_argument(
"--train_batch_size",
Expand Down

0 comments on commit 395bbc1

Please sign in to comment.