From 05ce8eb1e6f5a844d2592ecb73e1cf8e3b870710 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Wed, 26 Jun 2024 02:57:36 -0400 Subject: [PATCH] minor --- examples/glue/README.md | 9 +++++---- examples/glue/analyze.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/glue/README.md b/examples/glue/README.md index 9b4671a..ac6e182 100644 --- a/examples/glue/README.md +++ b/examples/glue/README.md @@ -63,7 +63,7 @@ Can we remove top positively influential training examples to make some queries selects correctly classified query data point, removes top-k positively influential training samples, and retrain the network with the modified dataset to see if that query data point gets misclassified. -We first need to compute pairwise influence scores for the `RTE` dataset: +We first need to compute pairwise influence scores for the `RTE` dataset (A6000 GPU was used to run these experiments): ```bash python train.py --dataset_name rte \ @@ -76,13 +76,13 @@ python train.py --dataset_name rte \ --seed 1004 python analyze.py --dataset_name rte \ - --query_batch_size 175 \ + --query_batch_size 70 \ --train_batch_size 128 \ --checkpoint_dir ./checkpoints \ --factor_strategy ekfac python analyze.py --dataset_name rte \ - --query_batch_size 175 \ + --query_batch_size 139 \ --train_batch_size 128 \ --checkpoint_dir ./checkpoints \ --factor_strategy identity @@ -97,7 +97,8 @@ python analyze.py --dataset_name rte \ ## Evaluating Linear Datamodeling Score The `evaluate_lds.py` script computes the [linear datamodeling score (LDS)](https://arxiv.org/abs/2303.14186). It measures the LDS obtained by -retraining the network 500 times with different subsets of the dataset (5 repeats and 100 masks). By running `evaludate_lds.py`, we obtain `xx` LDS (we get `xx` LDS with the half precision). +retraining the network 500 times with different subsets of the dataset (5 repeats and 100 masks). 
+By running `evaluate_lds.py`, we obtain `xx` LDS (we get `xx` LDS with half precision). The script also includes functionality to print out top influential sequences for a given query. diff --git a/examples/glue/analyze.py b/examples/glue/analyze.py index df610f5..a824d5d 100644 --- a/examples/glue/analyze.py +++ b/examples/glue/analyze.py @@ -176,7 +176,7 @@ def main(): dataset=train_dataset, per_device_batch_size=None, factor_args=factor_args, - overwrite_output_dir=True, + overwrite_output_dir=False, initial_per_device_batch_size_attempt=512, )