From 3464ab3e5471bc33d3598b698d83f8ab483fa2ee Mon Sep 17 00:00:00 2001 From: ddlBoJack Date: Thu, 25 Jan 2024 23:50:06 +0800 Subject: [PATCH 1/2] update step lr scheduler and validation statistics --- scripts/compute_wer.sh | 4 +- scripts/finetune_asr_vicuna.sh | 19 +- scripts/inference_asr_batch.sh | 26 +- src/llama_recipes/configs/datasets.py | 1 + src/llama_recipes/configs/training.py | 3 + src/llama_recipes/datasets/speech_dataset.py | 22 +- .../datasets/speech_dataset_inference.py | 18 +- src/llama_recipes/models/slam_model.py | 15 +- src/llama_recipes/pipeline/finetune.py | 10 +- src/llama_recipes/pipeline/inference_batch.py | 3 +- src/llama_recipes/utils/compute_ppl.py | 5 +- src/llama_recipes/utils/train_utils.py | 285 ++++++++++-------- 12 files changed, 255 insertions(+), 156 deletions(-) diff --git a/scripts/compute_wer.sh b/scripts/compute_wer.sh index 9bf4c0bc..849ef210 100644 --- a/scripts/compute_wer.sh +++ b/scripts/compute_wer.sh @@ -1,7 +1,7 @@ #cd /root/SLAM-LLM -trans="/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115/asr/2/decode_log_test_other_beam4_repetition_penalty1_gt" -preds="/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115/asr/2/decode_log_test_other_beam4_repetition_penalty1_pred" +trans="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplr-whisper-largev2-prompt-lowergt-padding30-20240124/asr/3/decode_log_test_clean_beam4_repetition_penalty1_gt" +preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplr-whisper-largev2-prompt-lowergt-padding30-20240124/asr/3/decode_log_test_clean_beam4_repetition_penalty1_pred" # python src/llama_recipes/utils/preprocess_text.py ${preds} ${preds}.proc # python src/llama_recipes/utils/compute_wer.py ${trans} ${preds}.proc ${preds}.proc.wer diff --git a/scripts/finetune_asr_vicuna.sh 
b/scripts/finetune_asr_vicuna.sh index 6e3b52b2..a48dc7d5 100644 --- a/scripts/finetune_asr_vicuna.sh +++ b/scripts/finetune_asr_vicuna.sh @@ -1,7 +1,8 @@ #!/bin/bash # export PYTHONPATH=/root/whisper:$PYTHONPATH export PYTHONPATH=/root/fairseq:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=2,3,4,5 +export CUDA_VISIBLE_DEVICES=0,1 +export TOKENIZERS_PARALLELISM=false # export CUDA_LAUNCH_BLOCKING=1 export OMP_NUM_THREADS=1 @@ -12,13 +13,17 @@ export OMP_NUM_THREADS=1 cd /root/SLAM-LLM -# speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt -speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/tiny.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/base.pt +# speech_encoder_path=//nfs/maziyang.mzy/models/Whisper/small.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/medium.pt +speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 # llm_path=/nfs/maziyang.mzy/models/vicuna-13b-v1.5 -output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-qwen-prompt-padding30-20240113 +output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-base-prompt-padding30-20240117 # -m debugpy --listen 5678 --wait-for-client if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then @@ -28,7 +33,7 @@ python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/fin --freeze_llm \ --llm_name vicuna-13b-v1.5 \ --llm_path $llm_path \ ---llm_dim 5120 \ +--llm_dim 4096 \ --encoder_name whisper \ --encoder_ds_rate 2 \ --encoder_path $speech_encoder_path \ @@ -60,8 +65,7 @@ python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/fin else torchrun \ --nnodes 1 \ ---nproc_per_node 4 \ ---master_port=29502 \ +--nproc_per_node 2 \ src/llama_recipes/pipeline/finetune.py \ --model_name asr \ 
--freeze_encoder \ @@ -98,6 +102,7 @@ src/llama_recipes/pipeline/finetune.py \ # --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \ # --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4/model.pt" \ # --use_peft --peft_method lora \ +# --master_port=29501 \ fi # {"key": "1001-134707-0000_ASR", "prompt": "", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": ""} diff --git a/scripts/inference_asr_batch.sh b/scripts/inference_asr_batch.sh index 32fd8811..cd608405 100644 --- a/scripts/inference_asr_batch.sh +++ b/scripts/inference_asr_batch.sh @@ -6,27 +6,31 @@ export TOKENIZERS_PARALLELISM=false cd /root/SLAM-LLM -speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/tiny.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/base.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/small.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/medium.pt +speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt # speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt +# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-intermediate-step-1431k-3T +# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4 # llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf -# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf +llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf # llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 -llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4 
-output_dir=/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115 -ckpt_path=$output_dir/asr/2 +output_dir=/nfs/maziyang.mzy/exps/Llama-2-7b-chat-hf-finetune-asr-ds5-proj2048-steplr-whisper-largev2-prompt-lowergt-padding30-20240124 +ckpt_path=$output_dir/asr/4 # peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102/asr/4 -val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_other_filtered.jsonl -decode_log=$ckpt_path/decode_log_test_other_beam4_repetition_penalty1 +val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_clean_filtered.jsonl +decode_log=$ckpt_path/decode_log_test_clean_beam4_repetition_penalty1 # -m debugpy --listen 5678 --wait-for-client python src/llama_recipes/pipeline/inference_batch.py \ --model_name asr \ ---freeze_encoder \ ---llm_name tinyllama-1.1b-chat-v0.4 \ +--llm_name Llama-2-7b-chat-hf \ --llm_path $llm_path \ ---llm_dim 2048 \ +--llm_dim 4096 \ --encoder_name whisper \ --encoder_ds_rate 2 \ --encoder_path $speech_encoder_path \ @@ -44,5 +48,7 @@ python src/llama_recipes/pipeline/inference_batch.py \ --ckpt_path $ckpt_path/model.pt \ --decode_log $decode_log \ --freeze_llm \ +--freeze_encoder \ +# --speech_dataset.prompt "Transcribe speech to text." 
\ # --peft_ckpt $peft_ckpt \ # --use_peft --peft_method lora \ \ No newline at end of file diff --git a/src/llama_recipes/configs/datasets.py b/src/llama_recipes/configs/datasets.py index e8a6a05e..ceaaf928 100644 --- a/src/llama_recipes/configs/datasets.py +++ b/src/llama_recipes/configs/datasets.py @@ -39,6 +39,7 @@ class speech_dataset: max_words: int = None max_mel: int = None fix_length_audio: int = -1 + prompt: str = None @dataclass diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py index 1239c7f5..356d3058 100644 --- a/src/llama_recipes/configs/training.py +++ b/src/llama_recipes/configs/training.py @@ -10,12 +10,15 @@ class train_config: enable_fsdp: bool=False low_cpu_fsdp: bool=False run_validation: bool=True + validation_interval: int=1000 batch_size_training: int=4 batching_strategy: str="packing" #alternative: padding context_length: int=4096 gradient_accumulation_steps: int=1 num_epochs: int=3 num_workers_dataloader: int=1 + warmup_steps: int=1000 + total_steps: int=50000 lr: float=1e-4 weight_decay: float=0.0 gamma: float= 0.85 diff --git a/src/llama_recipes/datasets/speech_dataset.py b/src/llama_recipes/datasets/speech_dataset.py index 5a51d330..437201c1 100644 --- a/src/llama_recipes/datasets/speech_dataset.py +++ b/src/llama_recipes/datasets/speech_dataset.py @@ -29,6 +29,19 @@ def __init__(self, # self.data_list = contents self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss + self.prompt = dataset_config.prompt + # self.prompt_library = [ + # "Begin by converting the spoken words into written text. ", + # "Can you transcribe the speech into a written format? ", + # "Focus on translating the audible content into text. ", + # "Transcribe the speech by carefully listening to it. ", + # "Would you kindly write down the content of the speech? ", + # "Analyze the speech and create a written transcription. ", + # "Engage with the speech to produce a text-based version. 
", + # "Can you document the speech in written form? ", + # "Transform the spoken words into text accurately. ", + # "How about putting the speech's content into writing? " + # ] self.prompt_template = "USER: {}\n ASSISTANT:" self.answer_template = "{}" self.fix_length_audio = dataset_config.fix_length_audio @@ -76,10 +89,13 @@ def __getitem__(self, index): # audio_raw = np.concatenate((np.zeros(random.randint(0, 16000)), audio_raw, np.zeros(random.randint(0, 16000)))).astype(audio_raw.dtype)[:16000*30] audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1, 0) - prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " - + prompt = self.prompt + if prompt is None: + # prompt = random.choice(self.prompt_library) + prompt = "Transcribe speech to text. " + # prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " prompt = self.prompt_template.format(prompt) - answer = self.answer_template.format(target) + answer = self.answer_template.format(target.lower()) prompt_ids = self.tokenizer.encode(prompt) diff --git a/src/llama_recipes/datasets/speech_dataset_inference.py b/src/llama_recipes/datasets/speech_dataset_inference.py index 1a1155dc..8b6ecb9e 100644 --- a/src/llama_recipes/datasets/speech_dataset_inference.py +++ b/src/llama_recipes/datasets/speech_dataset_inference.py @@ -29,6 +29,19 @@ def __init__(self, # self.data_list = contents self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss + self.prompt = dataset_config.prompt + # self.prompt_library = [ + # "Begin by converting the spoken words into written text. ", + # "Can you transcribe the speech into a written format? ", + # "Focus on translating the audible content into text. ", + # "Transcribe the speech by carefully listening to it. ", + # "Would you kindly write down the content of the speech? 
", + # "Analyze the speech and create a written transcription. ", + # "Engage with the speech to produce a text-based version. ", + # "Can you document the speech in written form? ", + # "Transform the spoken words into text accurately. ", + # "How about putting the speech's content into writing? " + # ] self.prompt_template = "USER: {}\n ASSISTANT:" self.fix_length_audio = dataset_config.fix_length_audio @@ -72,7 +85,10 @@ def __getitem__(self, index): # audio_raw = np.concatenate((np.zeros(random.randint(0, 16000)), audio_raw, np.zeros(random.randint(0, 16000)))).astype(audio_raw.dtype)[:16000*30] audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1, 0) - prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " + prompt = self.prompt + if prompt is None: + # prompt = random.choice(self.prompt_library) + prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. 
" prompt = self.prompt_template.format(prompt) prompt_ids = self.tokenizer.encode(prompt) diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py index ad3d1e8a..5214cc52 100644 --- a/src/llama_recipes/models/slam_model.py +++ b/src/llama_recipes/models/slam_model.py @@ -25,14 +25,15 @@ def setup_model(tokenizer, train_config, model_config, **kwargs): def setup_tokenizer(train_config, model_config, **kwargs): # Load the tokenizer and add special tokens - if "llama" in model_config.llm_name or "vicuna" in model_config.llm_name: - tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path) - tokenizer.pad_token_id = tokenizer.eos_token_id - return tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path) + tokenizer.pad_token_id = tokenizer.eos_token_id + return tokenizer def setup_encoder(train_config, model_config, **kwargs): - encoder_list = model_config.encoder_name.split(",") + encoder_list = model_config.encoder_name.split(",") if model_config.encoder_name else [] + if len(encoder_list) == 0: + return None if len(encoder_list) == 1: encoder_name = encoder_list[0] if encoder_name == "whisper" or encoder_name == "qwen-audio": @@ -198,6 +199,8 @@ def forward(self, encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask) # bs*seq*dim if self.model_config.encoder_name == "moco_wav2vec2": encoder_outs , inputLenBatch, audio_mel_post_mask = self.encoder((audio, audio_mask, visual, vis_len) ,maskw2v) # bs*seq*dim + if self.encoder is None: + encoder_outs = audio_mel if audio_mel is not None else audio if self.model_config.encoder_projector == "q-former": encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) @@ -309,7 +312,7 @@ def inference( negative_prompt_ids = None, negative_prompt_attention_mask = None, **kwargs, - ): # TODO: Now you need to set your customized sampling rate manually + ): device = kwargs.get("device", "cuda") if 
os.path.exists(wav_path): # Audio-Text QA diff --git a/src/llama_recipes/pipeline/finetune.py b/src/llama_recipes/pipeline/finetune.py index 6ce5a71c..b038f4c0 100644 --- a/src/llama_recipes/pipeline/finetune.py +++ b/src/llama_recipes/pipeline/finetune.py @@ -205,7 +205,15 @@ def main(**kwargs): lr=train_config.lr, weight_decay=train_config.weight_decay, ) - scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma) + # scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma) + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=lambda step: ( + min(step / train_config.warmup_steps, 1) if step < train_config.warmup_steps + else 1 + # else max(0.0, 1 - (step - train_config.warmup_steps) / (train_config.total_steps - train_config.warmup_steps)) + ) + ) # Start the training process results = train( diff --git a/src/llama_recipes/pipeline/inference_batch.py b/src/llama_recipes/pipeline/inference_batch.py index c413de86..7e8a7309 100644 --- a/src/llama_recipes/pipeline/inference_batch.py +++ b/src/llama_recipes/pipeline/inference_batch.py @@ -18,6 +18,7 @@ from llama_recipes.utils.dataset_utils import get_preprocessed_dataset import os import logging +from tqdm import tqdm def main(**kwargs): @@ -89,7 +90,7 @@ def main(**kwargs): pred_path = kwargs.get('decode_log') + "_pred" gt_path = kwargs.get('decode_log') + "_gt" with open(pred_path, "w") as pred, open(gt_path, "w") as gt: - for step, batch in enumerate(test_dataloader): + for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)): for key in batch.keys(): batch[key] = batch[key].to(device) if key not in ["keys", "targets"] else batch[key] model_outputs = model.generate(**batch) diff --git a/src/llama_recipes/utils/compute_ppl.py b/src/llama_recipes/utils/compute_ppl.py index 55baa292..a5521f3c 100644 --- a/src/llama_recipes/utils/compute_ppl.py +++ b/src/llama_recipes/utils/compute_ppl.py @@ -7,7 +7,7 @@ tokenizer = 
AutoTokenizer.from_pretrained(MODEL_PATH) model = AutoModelForCausalLM.from_pretrained(MODEL_PATH) -device = 'cuda:6' +device = 'cuda:0' model.to(device) model.eval() @@ -25,7 +25,8 @@ inputs = tokenizer(sentence, return_tensors="pt").to(device) input_ids = inputs["input_ids"] - input_len = input_ids.size(1) + # input_len = input_ids.size(1) + input_len = len(sentence.split(" ")) total_tokens += input_len with torch.no_grad(): diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py index 1fa73123..45739cb5 100644 --- a/src/llama_recipes/utils/train_utils.py +++ b/src/llama_recipes/utils/train_utils.py @@ -81,7 +81,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche checkpoint_times = [] results = {} best_val_loss = float("inf") - best_val_acc = float("inf") + best_val_acc = 0.0 for epoch in range(train_config.num_epochs): epoch_start_time = time.perf_counter() with MemoryTrace() as memtrace: # track the memory usage @@ -121,6 +121,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: scaler.step(optimizer) scaler.update() + if lr_scheduler is not None: + lr_scheduler.step() + current_lr = lr_scheduler.get_last_lr()[0] + else: + current_lr = optimizer.param_groups[0]["lr"] + if current_lr == 0: + break + if log_config.use_wandb and step % log_config.log_interval == 0: + if train_config.enable_fsdp: + if rank==0: + wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step)) + else: + wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step)) optimizer.zero_grad() pbar.update(1) else: @@ -128,10 +141,146 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche loss.backward() if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: optimizer.step() + if lr_scheduler is not None: + 
lr_scheduler.step() + current_lr = lr_scheduler.get_last_lr()[0] + else: + current_lr = optimizer.param_groups[0]["lr"] + if current_lr == 0: + break + if log_config.use_wandb and step % log_config.log_interval == 0: + if train_config.enable_fsdp: + if rank==0: + wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step)) + else: + wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step)) optimizer.zero_grad() pbar.update(1) pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()}, acc: {acc})") + + if (epoch * total_length + step) % train_config.validation_interval == 0 and train_config.run_validation: + eval_ppl, eval_epoch_loss, *rest = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer) + eval_epoch_acc = rest[0] if rest else -1 + checkpoint_start_time = time.perf_counter() + if train_config.save_model and (eval_epoch_loss < best_val_loss): + if train_config.enable_fsdp: + dist.barrier() + if train_config.use_peft: + if train_config.enable_fsdp: + if rank==0: + logger.info(f"we are about to save the PEFT modules") + else: + logger.info(f"we are about to save the PEFT modules") + if train_config.enable_fsdp: #(FIX:MZY):We now only support full_shard and no_shard. 
+ if fsdp_config.sharding_strategy == ShardingStrategy.FULL_SHARD: + save_model_checkpoint_peft_full_shard( + model, optimizer, rank, train_config, epoch=epoch + ) + elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD: + if rank==0: + save_model_checkpoint_peft( + model, optimizer, rank, train_config, epoch=epoch + ) + dist.barrier() + else: + # model.save_pretrained(train_config.output_dir) + save_model_checkpoint_peft( + model, optimizer, rank, train_config, epoch=epoch + ) + if train_config.enable_fsdp: + if rank==0: + logger.info(f"PEFT modules are saved in {train_config.output_dir} directory") + else: + logger.info(f"PEFT modules are saved in {train_config.output_dir} directory") + + elif not train_config.use_peft and train_config.freeze_llm: + logger.info(f"llm is frozen, we are about to save other parts.") + if train_config.enable_fsdp: #(FIX:MZY):We now only support full_shard and no_shard. + if fsdp_config.sharding_strategy == ShardingStrategy.FULL_SHARD: + save_model_checkpoint_peft_full_shard( + model, optimizer, rank, train_config, epoch=epoch + ) + elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD: + if rank==0: + save_model_checkpoint_peft( + model, optimizer, rank, train_config, epoch=epoch + ) + dist.barrier() + else: + save_model_checkpoint_peft( + model, optimizer, rank, train_config, epoch=epoch + ) + + else: + if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: + + save_model_checkpoint( + model, optimizer, rank, train_config, epoch=epoch + ) + elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: + logger.info(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT") + logger.info("=====================================================") + + save_model_and_optimizer_sharded(model, rank, train_config) + if train_config.save_optimizer: + save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) + logger.info(" 
Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT") + logger.info("=====================================================") + + if not train_config.use_peft and train_config.save_optimizer: + save_optimizer_checkpoint( + model, optimizer, rank, train_config, epoch=epoch + ) + logger.info(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT") + logger.info("=====================================================") + if train_config.enable_fsdp: + dist.barrier() + checkpoint_end_time = time.perf_counter() - checkpoint_start_time + checkpoint_times.append(checkpoint_end_time) + if eval_epoch_loss < best_val_loss: + best_val_loss = eval_epoch_loss + if train_config.enable_fsdp: + if rank==0: + logger.info(f"best eval loss on epoch {epoch+1} is {best_val_loss}") + else: + logger.info(f"best eval loss on epoch {epoch+1} is {best_val_loss}") + val_loss.append(eval_epoch_loss) + val_prep.append(eval_ppl) + if rest: + if eval_epoch_acc > best_val_acc: + best_val_acc = eval_epoch_acc + if train_config.enable_fsdp: + if rank==0: + logger.info(f"best eval acc on epoch {epoch+1} is {best_val_acc}") + else: + logger.info(f"best eval acc on epoch {epoch+1} is {best_val_acc}") + val_acc.append(rest[0]) + else: + val_acc.append(-1) + + if log_config.use_wandb: + if train_config.enable_fsdp: + if rank==0: + wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1], "valid/val_best_accuracy":best_val_acc}) + else: + wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1], "valid/val_best_accuracy":best_val_acc}) + + if train_config.run_test_during_validation: + if train_config.enable_fsdp: + if rank==0: + logger.info("=====================================") + logger.info(f"Test the file {train_config.run_test_during_validation_file} during validation:") + 
with autocast(): + logger.info(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt)) + logger.info("=====================================") + dist.barrier() + else: + logger.info("=====================================") + logger.info(f"Test the file {train_config.run_test_during_validation_file} during validation:") + with autocast(): + logger.info(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt)) + logger.info("=====================================") pbar.close() epoch_end_time = time.perf_counter()-epoch_start_time @@ -158,6 +307,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche else: wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc}) + if train_config.enable_fsdp: + if rank==0: + logger.info(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") + else: + logger.info(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") + if train_config.enable_fsdp: if rank==0: logger.info(f"Max CUDA memory allocated was {memtrace.peak} GB") @@ -173,128 +328,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche logger.info(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB") # Update the learning rate as needed - lr_scheduler.step() - - if train_config.run_validation: - eval_ppl, eval_epoch_loss, *rest = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer) - eval_epoch_acc = rest[0] if rest else -1 - checkpoint_start_time = time.perf_counter() - if train_config.save_model and (eval_epoch_loss < best_val_loss or eval_epoch_acc > best_val_acc): - if train_config.enable_fsdp: - 
dist.barrier() - if train_config.use_peft: - if train_config.enable_fsdp: - if rank==0: - logger.info(f"we are about to save the PEFT modules") - else: - logger.info(f"we are about to save the PEFT modules") - if train_config.enable_fsdp: #(FIX:MZY):We now only support full_shard and no_shard. - if fsdp_config.sharding_strategy == ShardingStrategy.FULL_SHARD: - save_model_checkpoint_peft_full_shard( - model, optimizer, rank, train_config, epoch=epoch - ) - elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD: - if rank==0: - save_model_checkpoint_peft( - model, optimizer, rank, train_config, epoch=epoch - ) - dist.barrier() - else: - # model.save_pretrained(train_config.output_dir) - save_model_checkpoint_peft( - model, optimizer, rank, train_config, epoch=epoch - ) - if train_config.enable_fsdp: - if rank==0: - logger.info(f"PEFT modules are saved in {train_config.output_dir} directory") - else: - logger.info(f"PEFT modules are saved in {train_config.output_dir} directory") - - elif not train_config.use_peft and train_config.freeze_llm: - logger.info(f"llm is frozen, we are about to save other parts.") - if train_config.enable_fsdp: #(FIX:MZY):We now only support full_shard and no_shard. 
- if fsdp_config.sharding_strategy == ShardingStrategy.FULL_SHARD: - save_model_checkpoint_peft_full_shard( - model, optimizer, rank, train_config, epoch=epoch - ) - elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD: - if rank==0: - save_model_checkpoint_peft( - model, optimizer, rank, train_config, epoch=epoch - ) - dist.barrier() - else: - save_model_checkpoint_peft( - model, optimizer, rank, train_config, epoch=epoch - ) + # lr_scheduler.step() - else: - if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: - - save_model_checkpoint( - model, optimizer, rank, train_config, epoch=epoch - ) - elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: - logger.info(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT") - logger.info("=====================================================") - - save_model_and_optimizer_sharded(model, rank, train_config) - if train_config.save_optimizer: - save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) - logger.info(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT") - logger.info("=====================================================") - - if not train_config.use_peft and train_config.save_optimizer: - save_optimizer_checkpoint( - model, optimizer, rank, train_config, epoch=epoch - ) - logger.info(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT") - logger.info("=====================================================") - if train_config.enable_fsdp: - dist.barrier() - checkpoint_end_time = time.perf_counter() - checkpoint_start_time - checkpoint_times.append(checkpoint_end_time) - if eval_epoch_loss < best_val_loss: - best_val_loss = eval_epoch_loss - if train_config.enable_fsdp: - if rank==0: - logger.info(f"best eval loss on epoch {epoch+1} is {best_val_loss}") - else: - logger.info(f"best eval loss on epoch {epoch+1} is {best_val_loss}") - 
val_loss.append(eval_epoch_loss) - val_prep.append(eval_ppl) - if rest: - val_acc.append(rest[0]) - else: - val_acc.append(-1) - - if log_config.use_wandb: - if train_config.enable_fsdp: - if rank==0: - wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1]}) - else: - wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1]}) - - if train_config.run_test_during_validation: - if train_config.enable_fsdp: - if rank==0: - logger.info("=====================================") - logger.info(f"Test the file {train_config.run_test_during_validation_file} during validation:") - with autocast(): - logger.info(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt)) - logger.info("=====================================") - dist.barrier() - else: - logger.info("=====================================") - logger.info(f"Test the file {train_config.run_test_during_validation_file} during validation:") - with autocast(): - logger.info(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt)) - logger.info("=====================================") - if train_config.enable_fsdp: - if rank==0: - logger.info(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") - else: - logger.info(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") avg_epoch_time = sum(epoch_times)/ len(epoch_times) avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0 avg_train_prep = sum(train_prep)/len(train_prep) @@ -342,7 +377,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): 
autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext # (Fix:MZY): fix expected scalar type mismatch in norm with MemoryTrace() as memtrace: - for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)): + total_length = len(eval_dataloader) + pbar = tqdm(colour="green", desc=f"Evaluating Epoch", total=total_length, dynamic_ncols=True) + for step, batch in enumerate(eval_dataloader): for key in batch.keys(): if type(batch[key])==bool: continue @@ -365,6 +402,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): eval_preds.extend( tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True) ) + pbar.update(1) + pbar.set_description(f"step: {step+1}/{total_length}, eval_loss: {eval_loss/(step+1):.4f}, eval_acc: {eval_acc/(step+1):.4f}") # If there's more than one CUDA device, reduce evaluation loss across all devices if torch.cuda.device_count() > 1 and train_config.enable_fsdp: From de6d84e52741a69e0a34260cd6a9e272a353f18a Mon Sep 17 00:00:00 2001 From: ddlBoJack Date: Fri, 26 Jan 2024 03:03:24 +0800 Subject: [PATCH 2/2] delete config --- src/llama_recipes/configs/datasets.py | 80 --------------------------- src/llama_recipes/configs/training.py | 46 --------------- 2 files changed, 126 deletions(-) delete mode 100644 src/llama_recipes/configs/datasets.py delete mode 100644 src/llama_recipes/configs/training.py diff --git a/src/llama_recipes/configs/datasets.py b/src/llama_recipes/configs/datasets.py deleted file mode 100644 index ceaaf928..00000000 --- a/src/llama_recipes/configs/datasets.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
- -from dataclasses import dataclass - - -@dataclass -class samsum_dataset: - dataset: str = "samsum_dataset" - train_split: str = "train" - test_split: str = "validation" - - -@dataclass -class grammar_dataset: - dataset: str = "grammar_dataset" - train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" - test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv" - - -@dataclass -class alpaca_dataset: - dataset: str = "alpaca_dataset" - train_split: str = "train" - test_split: str = "val" - data_path: str = "src/llama_recipes/datasets/alpaca_data.json" - - -@dataclass -class speech_dataset: - dataset: str = "speech_dataset" - file: str = "src/llama_recipes/datasets/speech_dataset.py:get_speech_dataset" - train_split: str = "train" - test_split: str = "validation" - data_path: str = None - max_words: int = None - train_data_path: str = None - val_data_path: str = None - max_words: int = None - max_mel: int = None - fix_length_audio: int = -1 - prompt: str = None - - -@dataclass -class audio_dataset: - dataset: str = "audio_dataset" - file: str = "src/llama_recipes/datasets/audio_dataset.py:get_audio_dataset" - train_split: str = "train" - test_split: str = "validation" - data_path: str = None - fbank_mean: float = 15.41663 - fbank_std: float = 6.55582 - max_words: int = None - train_data_path: str = None - val_data_path: str = None - max_words: int = None - max_mel: int = None - fix_length_audio: int = -1 - - -@dataclass -class avsr_dataset: - dataset: str = "avsr_dataset" - file: str = "examples/avsr_dataset.py" - train_split: str = "train" - test_split: str = "val" - data_path: str = "/nfs/yangguanrou.ygr/" #"/home/oss/yangguanrou.ygr/" - h5file: str = "/nfs/yangguanrou.ygr/LRS3/LRS3.h5" # "/home/oss/yangguanrou.ygr/LRS3/LRS3.h5" - noiseFile : str = "/nfs/yangguanrou.ygr/AVSR/LRS3/Noise.h5" #"/home/oss/yangguanrou.ygr/AVSR/LRS3/Noise.h5" - noiseProb: float = 0. 
- noiseSNR: float = 5 - stepSize: int = 16384 - charToIx : str = "x" #应该没用了 TypeError: Object of type NotImplementedType is not JSON serializable 但这个是上面的问题 - modal: str = "AV" - pretrain_subset: str = "LRS3/pretrain.txt" - train_subset: str = "LRS3/train.txt" - valid_subset: str = "LRS3/val.txt" - test_subset: str = "LRS3/test.txt" - reqInpLen: str = 80 diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py deleted file mode 100644 index 356d3058..00000000 --- a/src/llama_recipes/configs/training.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from dataclasses import dataclass - - -@dataclass -class train_config: - model_name: str="PATH/to/LLAMA/7B" - enable_fsdp: bool=False - low_cpu_fsdp: bool=False - run_validation: bool=True - validation_interval: int=1000 - batch_size_training: int=4 - batching_strategy: str="packing" #alternative: padding - context_length: int=4096 - gradient_accumulation_steps: int=1 - num_epochs: int=3 - num_workers_dataloader: int=1 - warmup_steps: int=1000 - total_steps: int=50000 - lr: float=1e-4 - weight_decay: float=0.0 - gamma: float= 0.85 - seed: int=42 - use_fp16: bool=False - mixed_precision: bool=True - val_batch_size: int=1 - dataset = "samsum_dataset" - peft_method: str = "lora" # None , llama_adapter, prefix - use_peft: bool=False - output_dir: str = "PATH/to/save/PEFT/model" - freeze_layers: bool = False - num_freeze_layers: int = 1 - quantization: bool = False - one_gpu: bool = False - save_model: bool = True - dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP - dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP - save_optimizer: bool=False # will be used if using FSDP - use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash 
Attention and Xformer memory-efficient kernels - run_test_during_validation: bool = False - run_test_during_validation_file: str = "test.wav" - run_test_during_validation_prompt: str = "<|ASR|>" - freeze_llm: bool = False - freeze_encoder: bool = False