
Commit

Merge pull request #33 from ddlBoJack/main
sync
ddlBoJack authored Jan 25, 2024
2 parents 5a86bfb + 6d30313 commit 73ea14d
Showing 19 changed files with 436 additions and 695 deletions.
4 changes: 3 additions & 1 deletion requirements.txt
@@ -12,4 +12,6 @@ transformers>=4.31.0
sentencepiece
py7zr
scipy
optimum
optimum
wandb
hydra-core>=1.3.2
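
The two dependencies added here, wandb and hydra-core, support the logging block (log_config) and the Hydra-driven configuration introduced in the new asr_vicuna_lora.yaml below. As a rough illustration only, the following sketch shows how such a log_config block might be passed to wandb.init; the helper name init_wandb and the dict-style access are assumptions, not code from this repository.

import wandb

def init_wandb(log_config: dict, run_config: dict = None):
    # Start a Weights & Biases run only when use_wandb is enabled in log_config.
    if not log_config.get("use_wandb", False):
        return None
    return wandb.init(
        project=log_config["wandb_project_name"],
        entity=log_config["wandb_entity_name"],
        name=log_config["wandb_exp_name"],
        dir=log_config["wandb_dir"],
        config=run_config or {},
    )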
111 changes: 111 additions & 0 deletions scripts/conf/asr_vicuna_lora.yaml
@@ -0,0 +1,111 @@

model_config:
llm_name: "vicuna-13b-v1.5"
llm_path: "PATH/to/LLAMA/7B"
llm_dim: 4096
encoder_name: null
encoder_ds_rate: 2
encoder_path: null
encoder_dim: 1280
encoder_projector: "linear"
encoder_projector_ds_rate: 5

DMODEL: 512
FRONTEND_DMODEL: 1024 # this refers specifically to the MoCo frontend
TX_ATTENTION_HEADS: 8
TX_NUM_LAYERS: 6
PE_MAX_LENGTH: 500
AUDIO_FEATURE_SIZE: 1024
VIDEO_FEATURE_SIZE: 2048
TX_FEEDFORWARD_DIM: 2048
TX_DROPOUT: 0.1
CHAR_NUM_CLASSES: 40

WORD_NUM_CLASSES: 500
FRAME_LENGTH: 29
MOCO_FRONTEND_FILE: "/nfs/yangguanrou.ygr/AVSR/pretrain_model/moco_frontend.pt" #"/home/oss/yangguanrou.ygr/AVSR/pretrain_model/moco_frontend.pt"
WAV2VEC_FILE: "/nfs/yangguanrou.ygr/AVSR/pretrain_model/wav2vec_vox_new.pt" #"/home/oss/yangguanrou.ygr/AVSR/pretrain_model/wav2vec_vox_new.pt"
MAIN_REQ_INPUT_LENGTH: 80
modal: "AV"
TRAIN_LRS3_MODEL_FILE: "/nfs/yangguanrou.ygr/AVSR/train-step_0108-wer_0.058.ckpt" # "/home/oss/yangguanrou.ygr/AVSR/train-step_0108-wer_0.058.ckpt" # use this one for single-modality
TRAINED_AO_FILE: "/nfs/yangguanrou.ygr/AVSR/check/train-step_0604-wer_0.054.ckpt" #"/home/oss/yangguanrou.ygr/AVSR/check/train-step_0604-wer_0.054.ckpt"
TRAINED_VO_FILE: "/nfs/yangguanrou.ygr/AVSR/check/train-step_1191-wer_0.674.ckpt" #"/home/oss/yangguanrou.ygr/AVSR/check/train-step_1191-wer_0.674.ckpt"


train_config:
model_name: "PATH/to/LLAMA/7B"
enable_fsdp: false
low_cpu_fsdp: false
run_validation: true
batch_size_training: 4
batching_strategy: "packing" #alternative: padding
context_length: 4096
gradient_accumulation_steps: 1
num_epochs: 3
num_workers_dataloader: 1
lr: 1e-4
weight_decay: 0.0
gamma: 0.85
seed: 42
use_fp16: false
mixed_precision: true
val_batch_size: 1

use_peft: false
peft_config:
peft_method: "lora" # None , llama_adapter, prefix
r: 8
lora_alpha: 32
target_modules: [ "q_proj", "v_proj" ]
bias: null
task_type: "CAUSAL_LM"
lora_dropout: 0.05
inference_mode: false
output_dir: "PATH/to/save/PEFT/model"
freeze_layers: false
num_freeze_layers: 1
quantization: false
one_gpu: false
save_model: true
dist_checkpoint_root_folder: "PATH/to/save/FSDP/model" # will be used if using FSDP
dist_checkpoint_folder: "fine-tuned" # will be used if using FSDP
save_optimizer: false # will be used if using FSDP
use_fast_kernels: false # Enable SDPA from PyTorch's Accelerated Transformers, which makes use of Flash Attention and xFormers memory-efficient kernels
run_test_during_validation: false
run_test_during_validation_file: "test.wav"
run_test_during_validation_prompt: "<|ASR|>"
freeze_llm: false
freeze_encoder: false

dataset_config:
dataset: "samsum_dataset"
file: "src/llama_recipes/datasets/speech_dataset.py:get_speech_dataset"
train_data_path: null
val_data_path: null
train_split: "train"
test_split: "validation"
data_path: null
max_words: null
max_mel: null
fix_length_audio: -1

fsdp_config:
mixed_precision: true
use_fp16: false
# sharding_strategy: "FULL_SHARD" # ShardingStrategy.FULL_SHARD
sharding_strategy: "NO_SHARD" # ShardingStrategy.NO_SHARD # MZY: NO_SHARD makes FSDP behave like DDP
checkpoint_type: "StateDictType.SHARDED_STATE_DICT" # SHARDED_STATE_DICT saves one file per rank and allows the world size to change when resuming
fsdp_activation_checkpointing: true
fsdp_cpu_offload: false
pure_bf16: false
optimizer: "AdamW"

log_config:
use_wandb: false
wandb_dir: "/root/test_wandb"
wandb_entity_name : "project_name"
wandb_project_name : "project_name"
wandb_exp_name : "exp_name"
log_file: "/root/test.log"
log_interval: 5
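
The dataset_config.file entry above points at a dataset factory using a "path/to/module.py:function_name" spec. A hypothetical sketch of how such a spec can be resolved in plain Python is shown below; the loader actually used by the pipeline may differ.

import importlib.util
from pathlib import Path

def load_dataset_factory(spec: str):
    # Split "path/to/module.py:function_name" into the file path and function name.
    module_path, func_name = spec.split(":")
    module_spec = importlib.util.spec_from_file_location(Path(module_path).stem, module_path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, func_name)

# e.g. get_speech_dataset = load_dataset_factory(
#          "src/llama_recipes/datasets/speech_dataset.py:get_speech_dataset")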

135 changes: 72 additions & 63 deletions scripts/finetune_asr_vicuna.sh
@@ -22,79 +22,88 @@ output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-

# -m debugpy --listen 5678 --wait-for-client
if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
--model_name asr \
--freeze_encoder \
--freeze_llm \
--llm_name vicuna-13b-v1.5 \
--llm_path $llm_path \
--llm_dim 5120 \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_dim 1280 \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--dataset speech_dataset \
--speech_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
--speech_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
--batching_strategy custom \
--num_epochs 100 \
--batch_size_training 4 \
--val_batch_size 4 \
--num_workers_dataloader 4 \
--lr 1e-4 \
--output_dir $output_dir \
--metric acc \
# --log_file $output_dir/test.log \
# --use_wandb \
# --wandb_dir $output_dir \
# --wandb_entity_name zym22 \
# --wandb_project_name slam-llm \
# --wandb_exp_name test \
# --log_interval 5 \
python src/llama_recipes/pipeline/finetune.py \
--config-path "/root/SLAM-LLM/scripts/conf" \
--config-name "asr_vicuna_lora.yaml" \
++model_config.llm_name="vicuna-7b-v1.5" \
++model_config.llm_path=$llm_path \
++model_config.llm_dim=4096 \
++model_config.encoder_name=whisper \
++model_config.encoder_ds_rate=2 \
++model_config.encoder_path=$speech_encoder_path \
++model_config.encoder_dim=1280 \
++model_config.encoder_projector=linear \
++model_config.encoder_projector_ds_rate=5 \
++dataset_config.dataset=speech_dataset \
++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
++train_config.model_name=asr \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
++train_config.batching_strategy=custom \
++train_config.num_epochs=100 \
++train_config.batch_size_training=4 \
++train_config.val_batch_size=4 \
++train_config.num_workers_dataloader=4 \
++train_config.lr=1e-4 \
++train_config.output_dir=$output_dir \
++train_config.peft_config.peft_method=lora \
++metric=acc \
#++log_config.log_file=$output_dir/train.log \
#++log_config.use_wandb=true \
#++log_config.wandb_dir=$output_dir \
#++log_config.wandb_entity_name=zym22 \
#++log_config.wandb_project_name=slam-llm \
#++log_config.wandb_exp_name=test \
#++log_config.log_interval=5 \
# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5" \
# --use_peft --peft_method lora \

##vicuna-7b-v1.5
else
torchrun \
--nnodes 1 \
--nproc_per_node 4 \
--master_port=29502 \
src/llama_recipes/pipeline/finetune.py \
--model_name asr \
--freeze_encoder \
--freeze_llm \
--use_fp16 \
--enable_fsdp \
--llm_name vicuna-7b-v1.5 \
--llm_path $llm_path \
--llm_dim 4096 \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_dim 1280 \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--dataset speech_dataset \
--speech_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
--speech_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
--batching_strategy custom \
--num_epochs 100 \
--batch_size_training 4 \
--val_batch_size 4 \
--num_workers_dataloader 4 \
--lr 1e-4 \
--output_dir $output_dir \
--metric acc \
--log_file /$output_dir/train.log \
--use_wandb \
--wandb_dir $output_dir \
--wandb_entity_name zym22 \
--wandb_project_name slam-llm \
--wandb_exp_name test \
--log_interval 5 \
--config-path "/root/SLAM-LLM/scripts/conf" \
--config-name "asr_vicuna_lora.yaml" \
++model_config.llm_name="vicuna-7b-v1.5" \
++model_config.llm_path=$llm_path \
++model_config.llm_dim=4096 \
++model_config.encoder_name=whisper \
++model_config.encoder_ds_rate=2 \
++model_config.encoder_path=$speech_encoder_path \
++model_config.encoder_dim=1280 \
++model_config.encoder_projector=linear \
++model_config.encoder_projector_ds_rate=5 \
++dataset_config.dataset=speech_dataset \
++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
++train_config.model_name=asr \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
++train_config.batching_strategy=custom \
++train_config.num_epochs=100 \
++train_config.batch_size_training=4 \
++train_config.val_batch_size=4 \
++train_config.num_workers_dataloader=4 \
++train_config.lr=1e-4 \
++train_config.output_dir=$output_dir \
++train_config.peft_config.peft_method=lora \
++train_config.enable_fsdp=true \
++train_config.enable_ddp=false \
++train_config.use_fp16=true \
++metric=acc \
#++log_config.log_file=$output_dir/train.log \
#++log_config.use_wandb=true \
#++log_config.wandb_dir=$output_dir \
#++log_config.wandb_entity_name=zym22 \
#++log_config.wandb_project_name=slam-llm \
#++log_config.wandb_exp_name=test \
#++log_config.log_interval=5 \

# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \
# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4/model.pt" \
# --use_peft --peft_method lora \
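
The rewritten script replaces the old argparse-style flags with a Hydra invocation: --config-path and --config-name select the YAML above, and ++key=value overrides are merged into it at launch. A minimal, hypothetical Hydra entry point compatible with this invocation is sketched below; the repository's actual finetune.py may be structured differently.

import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="conf", config_name="asr_vicuna_lora")
def main(cfg: DictConfig) -> None:
    # Command-line ++overrides (e.g. ++train_config.lr=1e-4) are already merged into cfg here.
    print(OmegaConf.to_yaml(cfg.model_config))
    # ... build model, dataset, and trainer from cfg ...

if __name__ == "__main__":
    main()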
8 changes: 0 additions & 8 deletions src/llama_recipes/configs/__init__.py

This file was deleted.

79 changes: 0 additions & 79 deletions src/llama_recipes/configs/datasets.py

This file was deleted.

20 changes: 0 additions & 20 deletions src/llama_recipes/configs/fsdp.py

This file was deleted.

(Remaining file diffs not shown.)
