Skip to content

Commit

Permalink
Fix deprecated commands
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 9, 2024
1 parent 6de4afb commit 2e2802d
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 6 deletions.
10 changes: 10 additions & 0 deletions examples/glue/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@ This reduces computation time to about 30 minutes on an A100 (80GB) GPU.
----------------------------------------------------------------------------------------------------------------------------------
```

+ ```bash
+ python analyze.py --dataset_name sst2 \
+ --query_batch_size 175 \
+ --train_batch_size 128 \
+ --checkpoint_dir ./checkpoints \
+ --factor_strategy ekfac \
+ --use_half_precision \
+ --query_gradient_rank 32
+ ```

## Counterfactual Evaluation

Let's evaluate the impact of removing top positively influential training examples on query misclassification.
Expand Down
4 changes: 2 additions & 2 deletions examples/glue/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def main():
scores_name += "_half"
rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
if rank is not None:
- score_args.query_gradient_rank = rank
- score_args.num_query_gradient_accumulations = 10
+ score_args.query_gradient_low_rank = rank
+ score_args.query_gradient_accumulation_steps = 10
scores_name += f"_qlr{rank}"
analyzer.compute_pairwise_scores(
score_args=score_args,
Expand Down
3 changes: 2 additions & 1 deletion examples/imagenet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ To compute pairwise influence scores on 1000 query data points using the `ekfac`
```bash
python analyze.py --dataset_dir PATH_TO_IMAGENET \
--query_gradient_rank -1 \
+ --factor_batch_size 512 \
  --query_batch_size 100 \
- --train_batch_size 300 \
+ --train_batch_size 256 \
--factor_strategy ekfac
```

Expand Down
10 changes: 8 additions & 2 deletions examples/imagenet/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def parse_args():
default=-1,
help="Rank for the low-rank query gradient approximation.",
)
+ parser.add_argument(
+ "--factor_batch_size",
+ type=int,
+ default=512,
+ help="Batch size for computing factors.",
+ )
parser.add_argument(
"--query_batch_size",
type=int,
Expand Down Expand Up @@ -85,7 +91,7 @@ def main():
)
# Configure parameters for DataLoader.
dataloader_kwargs = DataLoaderKwargs(
- num_workers=4, pin_memory=True,
+ num_workers=4,
)
analyzer.set_dataloader_kwargs(dataloader_kwargs)

Expand All @@ -98,7 +104,7 @@ def main():
analyzer.fit_all_factors(
factors_name=factors_name,
dataset=train_dataset,
- per_device_batch_size=None,
+ per_device_batch_size=args.factor_batch_size,
factor_args=factor_args,
overwrite_output_dir=False,
)
Expand Down
2 changes: 1 addition & 1 deletion examples/imagenet/ddp_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def main():
)
# Configure parameters for DataLoader.
dataloader_kwargs = DataLoaderKwargs(
- num_workers=4, pin_memory=True,
+ num_workers=4,
)
analyzer.set_dataloader_kwargs(dataloader_kwargs)

Expand Down

0 comments on commit 2e2802d

Please sign in to comment.