Thank you for your interest in our work! This is the original implementation of our ACL 2024 paper, "TaSL: Continual Dialog State Tracking via Task Skill Localization and Consolidation", and includes the methods proposed in our latest extended work, "TaSL: Task Skill Localization and Consolidation for Language Model Continual Learning."
conda create -n TaSL python=3.8
conda activate TaSL
pip install -r requirements.txt
The preprocessed SGD dataset for Continual DST is provided in the `/data` folder. If you are interested in the preprocessing, please check `utils/preprocess.py` and `utils/dataloader.py`.
Additionally, we now support two newly introduced datasets from our journal extension: the Long Sequence Benchmark and the SuperNI Benchmark.
You can download the four different backbone models from the following links on Hugging Face:
Then replace the corresponding files in the Transformers package with our `trainer.py` and `trainer_seq2seq.py`, which modify the source code to add our importance-aware skill localization method.
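If you are unsure where the installed package lives, a minimal sketch like the following can locate it and copy the patched files in. This is a convenience snippet, not part of the repo, and assumes it is run from the repository root:

```python
# Locate the installed Transformers package and copy in the patched files.
# Assumes trainer.py and trainer_seq2seq.py sit in the current directory.
import os
import shutil

import transformers

pkg_dir = os.path.dirname(transformers.__file__)
for fname in ("trainer.py", "trainer_seq2seq.py"):
    shutil.copy(fname, os.path.join(pkg_dir, fname))
    print(f"Patched {os.path.join(pkg_dir, fname)}")
```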
We conducted experiments on four different backbone models. You can run training with the following scripts:
./scripts/run_train_TaSL_LLaMA7B.sh
./scripts/run_train_TaSL_t5.sh
- `--model_path`: path to the downloaded T5 model variant you want to fine-tune.
For LLaMA-7B, we use LoRA to speed up fine-tuning. At the end of training, the fine-tuned weights are stored in `$checkpoint_files`, and the importance distribution of the skill units is stored in `$ipt_file`.
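For intuition, here is a minimal sketch of how an importance score for skill localization can be computed, assuming a first-order sensitivity metric |w · ∂L/∂w| averaged within each skill unit (here, each parameter tensor). The names are illustrative, not the repo's actual API:

```python
import torch

def skill_unit_importance(model: torch.nn.Module, loss: torch.Tensor) -> dict:
    """Score each parameter tensor (treated as one skill unit) by the mean
    absolute weight-gradient product, a first-order sensitivity proxy."""
    loss.backward()
    return {
        name: (p.detach() * p.grad.detach()).abs().mean().item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }
```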
The code then automatically implements the fine-grained skill consolidation strategy (`skill_consolidation.py`).
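As a rough illustration of the idea (not the exact logic in `skill_consolidation.py`), consolidation can be sketched as importance-guided merging of previous-task and current-task weights per skill unit; the threshold and blending rule below are assumptions:

```python
import torch

def consolidate(prev_w: dict, curr_w: dict, prev_ipt: dict, curr_ipt: dict,
                threshold: float = 0.5) -> dict:
    """Merge per-skill-unit weights based on their importance to old and new tasks."""
    merged = {}
    for name in curr_w:
        if prev_ipt[name] >= threshold and curr_ipt[name] >= threshold:
            # shared skill unit: average old and new knowledge
            merged[name] = 0.5 * prev_w[name] + 0.5 * curr_w[name]
        elif curr_ipt[name] >= threshold:
            # task-specific skill unit for the new task: keep the new weights
            merged[name] = curr_w[name].clone()
        else:
            # unit unimportant for the new task: retain the previous weights
            merged[name] = prev_w[name].clone()
    return merged
```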
In our extended work, we introduce two new files:
- `finetune_TasLoRA.py`: implements LoRA-tailored skill units and introduces the Second-Order Gradient Approximation Metric.
- `skill_consolidation_TasLoRA.py`: implements the Adaptive Weighted Consolidation technique.
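A minimal sketch of what adaptive weighting could look like, assuming the merge coefficient of each skill unit is set by the relative importance of the new task versus the old ones (the exact rule in `skill_consolidation_TasLoRA.py` may differ):

```python
def adaptive_consolidate(prev_w: dict, curr_w: dict,
                         prev_ipt: dict, curr_ipt: dict, eps: float = 1e-8) -> dict:
    """Blend old and new weights per skill unit with an importance-derived coefficient."""
    merged = {}
    for name in curr_w:
        # alpha -> 1 when the unit matters mostly to the new task,
        # alpha -> 0 when it mostly encodes previously learned skills
        alpha = curr_ipt[name] / (curr_ipt[name] + prev_ipt[name] + eps)
        merged[name] = alpha * curr_w[name] + (1.0 - alpha) * prev_w[name]
    return merged
```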
Three metrics are used to measure the performance of our model for continual learning: average Joint Goal Accuracy (avg. JGA), Forward Transfer (FWT), and Backward Transfer (BWT). Run inference for each metric with the corresponding script:
./scripts/run_generate_avgJGA.sh
./scripts/run_generate_fwt.sh
./scripts/run_generate_bwt.sh
After inference, the generated prediction results are stored in the `/output` folder.
Then you can calculate the three metrics by running:
./eval_avgJGA.py
./eval_fwt.py
./eval_bwt.py
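For reference, the three metrics follow the standard continual-learning definitions. A minimal sketch, assuming `R[t][i]` holds the JGA on task `i` after training through task `t`; these are the common formulas, and the `eval_*.py` scripts remain the authoritative implementation:

```python
import numpy as np

def avg_jga(R: np.ndarray) -> float:
    """Average JGA over all tasks after training on the final task."""
    T = R.shape[0]
    return float(np.mean(R[T - 1]))

def bwt(R: np.ndarray) -> float:
    """Backward transfer: how training on later tasks changed earlier-task accuracy."""
    T = R.shape[0]
    return float(np.mean([R[T - 1, i] - R[i, i] for i in range(T - 1)]))

def fwt(R: np.ndarray) -> float:
    """Forward transfer: accuracy on each task just before training on it."""
    T = R.shape[0]
    return float(np.mean([R[i - 1, i] for i in range(1, T)]))
```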