Task Skill Localization and Consolidation (TaSL)

Thank you for your interest in our work! This is the original implementation of our ACL 2024 paper, "TaSL: Continual Dialog State Tracking via Task Skill Localization and Consolidation", and includes the methods proposed in our latest extended work, "TaSL: Task Skill Localization and Consolidation for Language Model Continual Learning."

Local Setup

conda create -n TaSL python=3.8
conda activate TaSL
pip install -r requirements.txt

Step 1. Preliminary Preparation

The preprocessed SGD dataset for Continual DST is provided in the /data folder. If you are interested in the pre-processing, please check utils/preprocess.py and utils/dataloader.py.

Additionally, we now support two newly introduced datasets from our journal extension: Long Sequence Benchmark and SuperNI Benchmark.

You can download the four backbone models from Hugging Face.

Then replace the corresponding files in the Transformers package with our trainer.py and trainer_seq2seq.py, which modify the original source code to add our importance-aware skill localization method.
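
For reference, the minimal sketch below (an assumption about a standard pip install of Transformers, not part of the repo) locates the installed package and copies the modified trainer files over it:

# Hedged sketch: overwrite the installed Transformers trainer modules with the
# modified versions shipped in this repo. The install path is resolved at
# runtime; only the two file names come from the repo itself.
import os
import shutil

import transformers

pkg_dir = os.path.dirname(transformers.__file__)
for fname in ("trainer.py", "trainer_seq2seq.py"):
    shutil.copy(fname, os.path.join(pkg_dir, fname))
    print(f"Replaced {os.path.join(pkg_dir, fname)}")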

Step 2. Training (TaSL)

We conducted experiments on four different backbone models:

LLaMA-7B (finetune_ContinualDST_LLaMA7B.py)

./scripts/run_train_TaSL_LLaMA7B.sh

T5 Series Models (finetune_ContinualDST_T5XL.py)

./scripts/run_train_TaSL_t5.sh
  • --model_path: set this to the path of the T5 model you want to fine-tune.

For LLaMA-7B, we use LoRA to accelerate the fine-tuning process. At the end of training, the fine-tuned weights are stored in $checkpoint_files, and the importance distribution of the skill units is stored in $ipt_file.
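
As a rough, hedged illustration of the idea (not the repo's exact code), the importance of each skill unit can be tracked during training with a first-order sensitivity score |w · ∇L(w)|, smoothed over steps:

# Hedged sketch of importance-aware skill localization. A "skill unit" is
# simplified here to a named parameter; after each backward pass we score it
# by |w * grad| and keep an exponential moving average. Names are illustrative.
def update_skill_importance(model, importance, beta=0.85):
    for name, param in model.named_parameters():  # model: a PyTorch nn.Module
        if param.grad is None:
            continue
        score = (param.detach() * param.grad.detach()).abs().mean()
        importance[name] = beta * importance.get(name, score) + (1 - beta) * score
    return importance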

The code then automatically implements the fine-grained skill consolidation strategy (skill_consolidation.py).
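
Conceptually, consolidation merges the previously accumulated weights with the newly fine-tuned ones per skill unit, weighted by their relative importance. The sketch below is a simplification under that assumption, not the actual skill_consolidation.py logic:

# Hedged sketch of fine-grained skill consolidation: interpolate each skill
# unit between its historical and newly learned weights according to their
# relative importance scores. All names here are illustrative.
def consolidate(old_weights, new_weights, old_ipt, new_ipt, eps=1e-8):
    merged = {}
    for name in new_weights:
        alpha = new_ipt[name] / (new_ipt[name] + old_ipt[name] + eps)
        merged[name] = alpha * new_weights[name] + (1 - alpha) * old_weights[name]
    return merged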

Step 2. Training (TasLoRA)

In our extended work, we introduce two new files:

  • finetune_TasLoRA.py: This file implements LoRA-tailored skill units and introduces the Second-Order Gradient Approximation Metric (a hedged sketch of the idea follows this list).
  • skill_consolidation_TasLoRA.py: This file implements the Adaptive Weighted Consolidation technique.
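
As a hedged sketch of the second-order idea only (the exact metric lives in finetune_TasLoRA.py), one can add a diagonal-Hessian term to the first-order sensitivity, with the Hessian diagonal approximated by the squared gradient (empirical Fisher):

# Hedged sketch of a second-order importance approximation for a LoRA
# parameter. The Hessian diagonal is approximated by the squared gradient;
# the exact combination used in the repo may differ.
def second_order_importance(param):
    w = param.detach()
    g = param.grad.detach()
    first_order = (w * g).abs()
    second_order = 0.5 * (w * w) * (g * g)  # 0.5 * w^2 * diag(H), with H ≈ g^2
    return (first_order + second_order).mean()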

Step 3. Inference

Three metrics are used to measure the performance of our model for Continual Learning:

Avg.JGA

./scripts/run_generate_avgJGA.sh

Forward Transfer (FWT)

./scripts/run_generate_fwt.sh

Backward Transfer (BWT)

./scripts/run_generate_bwt.sh

After inference, the generated prediction results will be stored in the /output folder.

Step 4. Evaluation

You can then calculate the three metrics by running:

./eval_avgJGA.py
./eval_fwt.py
./eval_bwt.py
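
For reference, one common formulation of these metrics (a generic sketch, not the repo's evaluation code) computes them from a T × T result matrix R, where R[i][j] is the JGA on task j after training on task i:

import numpy as np

# Hedged sketch of the three continual-learning metrics from a T x T matrix R
# of JGA scores (R[i][j] = JGA on task j after sequentially learning task i).
def avg_jga(R):
    return R[-1].mean()                                          # average final performance

def fwt(R):
    T = R.shape[0]
    return np.mean([R[i - 1, i] for i in range(1, T)])           # performance on tasks before training on them

def bwt(R):
    T = R.shape[0]
    return np.mean([R[-1, i] - R[i, i] for i in range(T - 1)])   # forgetting; more negative = more forgetting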
