diff --git a/examples/glue/README.md b/examples/glue/README.md index 71cecc3..2b3404d 100644 --- a/examples/glue/README.md +++ b/examples/glue/README.md @@ -24,9 +24,9 @@ python train.py --dataset_name sst2 \ To obtain a pairwise influence scores on maximum of 2000 query data points using `ekfac`, run the following command: ```bash -python analyze.py --query_batch_size 8 \ +python analyze.py --dataset_name sst2 \ + --query_batch_size 8 \ --train_batch_size 32 \ - --dataset_dir ./data \ --checkpoint_dir ./checkpoints \ --factor_strategy ekfac ``` diff --git a/examples/glue/analyze.py b/examples/glue/analyze.py index 53b630f..7416520 100644 --- a/examples/glue/analyze.py +++ b/examples/glue/analyze.py @@ -117,11 +117,11 @@ def main(): logging.basicConfig(level=logging.INFO) train_dataset = get_glue_dataset( - data_name=args.data_name, + data_name=args.dataset_name, split="eval_train", ) eval_dataset = get_glue_dataset( - data_name=args.data_name, + data_name=args.dataset_name, split="valid", )