From 72f1297841ed73b2789968add02bb7feb4effcfb Mon Sep 17 00:00:00 2001
From: Ziyang Ma
Date: Tue, 7 May 2024 18:30:48 +0000
Subject: [PATCH] clean up and update README

---
 README.md                                 |   7 +-
 examples/vsr_LRS3/model/slam_model_vsr.py |  80 +-----
 src/slam_llm/models/AV/av_net.py          | 228 ------------------
 src/slam_llm/models/AV/avsr_model.py      | 122 ----------
 .../models/AV/moco_visual_frontend.py     |  42 ----
 src/slam_llm/models/AV/utils.py           | 140 -----------
 src/slam_llm/models/AV/visual_encoder.py  |  90 -------
 src/slam_llm/pipeline/finetune.py         |  35 +--
 src/slam_llm/pipeline/inference_batch.py  |  24 +-
 src/slam_llm/utils/dataset_utils.py       |   2 +-
 src/slam_llm/utils/model_utils.py         |   2 +-
 11 files changed, 29 insertions(+), 743 deletions(-)
 delete mode 100644 src/slam_llm/models/AV/av_net.py
 delete mode 100644 src/slam_llm/models/AV/avsr_model.py
 delete mode 100644 src/slam_llm/models/AV/moco_visual_frontend.py
 delete mode 100644 src/slam_llm/models/AV/utils.py
 delete mode 100644 src/slam_llm/models/AV/visual_encoder.py

diff --git a/README.md b/README.md
index 11195570..0ce6e7d5 100644
--- a/README.md
+++ b/README.md
@@ -27,8 +27,10 @@ developers to train custom multimodal large language model (MLLM), focusing on <
 5. [Acknowledge](#acknowledge)
 
 # News
-- [Update Apr. 28, 2024] Recipes for automated audio captioning (AAC) with SOTA performance has been supported.
-- [Update Mar. 31, 2024] Recipes for automatic speech recognition (ASR) with SOTA performance has been supported.
+- [Update May 8, 2024] Recipes for [visual speech recognition (VSR)](examples/vsr_LRS3/README.md) have been supported.
+- [Update May 4, 2024] Recipes for [zero-shot text-to-speech (TTS)](examples/vallex/README.md) have been supported.
+- [Update Apr. 28, 2024] Recipes for [automated audio captioning (AAC)](examples/aac_audiocaps/README.md) have been supported.
+- [Update Mar. 31, 2024] Recipes for [automatic speech recognition (ASR)](examples/asr_librispeech/README.md) have been supported.
 
# Installation ```bash @@ -61,6 +63,7 @@ We provide reference implementations of various LLM-based speech, audio, and mus - **Speech Task** - [Automatic Speech Recognition (ASR)](examples/asr_librispeech/README.md) - [Text-to-Speech (TTS)](examples/vallex/README.md) + - [Visual Speech Recognition (VSR)](examples/vsr_LRS3/README.md) - **Audio Task** - [Automated Audio Captioning (AAC)](examples/aac_audiocaps/README.md) diff --git a/examples/vsr_LRS3/model/slam_model_vsr.py b/examples/vsr_LRS3/model/slam_model_vsr.py index 0910d2ed..bc7a8bf5 100644 --- a/examples/vsr_LRS3/model/slam_model_vsr.py +++ b/examples/vsr_LRS3/model/slam_model_vsr.py @@ -74,82 +74,4 @@ def __init__( train_config, model_config, **kwargs, - ) - - - @torch.no_grad() - def inference( - self, - wav_path=None, - prompt=None, - generation_config=None, - logits_processor=None, - stopping_criteria=None, - prefix_allowed_tokens_fn=None, - synced_gpus=None, - assistant_model=None, - streamer=None, - negative_prompt_ids=None, - negative_prompt_attention_mask=None, - **kwargs, - ): - # inference for asr model - - device = kwargs.get("device", "cuda") - if os.path.exists(wav_path): # Audio-Text QA - import whisper - - audio_raw = whisper.load_audio(wav_path) - audio_raw = whisper.pad_or_trim(audio_raw) - - mel_size = getattr( - self.dataset_config, "mel_size", 80 - ) # 80 for large v1 and v2, 128 for large v3 - audio_mel = ( - whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size) - .permute(1, 0)[None, :, :] - .to(device) - ) - - encoder_outs = self.encoder.extract_variable_length_features( - audio_mel.permute(0, 2, 1) - ) - - if self.model_config.encoder_projector == "q-former": - audio_mel_post_mask = torch.ones( - encoder_outs.size()[:-1], dtype=torch.long - ).to(encoder_outs.device) - encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) - if self.model_config.encoder_projector == "linear": - encoder_outs = self.encoder_projector(encoder_outs) - else: # Text QA - encoder_outs = torch.empty( - 1, 0, self.llm.model.embed_tokens.embedding_dim - ).to(device) - - prompt = "USER: {}\n ASSISTANT:".format(prompt) - prompt_ids = self.tokenizer.encode(prompt) - prompt_length = len(prompt_ids) - prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device) - - if hasattr(self.llm.model, "embed_tokens"): - inputs_embeds = self.llm.model.embed_tokens(prompt_ids) - elif hasattr(self.llm.model.model, "embed_tokens"): - inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids) - else: - inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids) - - inputs_embeds = torch.cat( - (encoder_outs, inputs_embeds[None, :, :]), dim=1 - ) # [audio,prompt] - - attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to( - inputs_embeds.device - ) - - # generate - model_outputs = self.generate( - inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs - ) - - return model_outputs + ) \ No newline at end of file diff --git a/src/slam_llm/models/AV/av_net.py b/src/slam_llm/models/AV/av_net.py deleted file mode 100644 index 8a4fa658..00000000 --- a/src/slam_llm/models/AV/av_net.py +++ /dev/null @@ -1,228 +0,0 @@ -from fairseq.checkpoint_utils import load_model_ensemble_and_task -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence -from tqdm import tqdm - -from .moco_visual_frontend import MoCoVisualFrontend -from .utils import PositionalEncoding, conv1dLayers, outputConv, MaskedLayerNorm, 
generate_square_subsequent_mask - -from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel - - -class AVNet(nn.Module): - - def __init__(self, model_config): - super(AVNet, self).__init__() - - self.modal = model_config.modal - self.numClasses = model_config.CHAR_NUM_CLASSES - self.reqInpLen = model_config.MAIN_REQ_INPUT_LENGTH - self.dModel= model_config.DMODEL #!!! - self.nHeads = model_config.TX_ATTENTION_HEADS - self.numLayers = model_config.TX_NUM_LAYERS - self.peMaxLen= model_config.PE_MAX_LENGTH - self.audinSize = model_config.AUDIO_FEATURE_SIZE - self.vidinSize = model_config.VIDEO_FEATURE_SIZE - self.fcHiddenSize = model_config.TX_FEEDFORWARD_DIM - self.dropout = model_config.TX_DROPOUT - self.MoCofile = model_config.MOCO_FRONTEND_FILE - self.W2Vfile = model_config.WAV2VEC_FILE - - # A & V Modal - tx_norm = nn.LayerNorm(self.dModel) - self.maskedLayerNorm = MaskedLayerNorm() - if self.modal == "AV": - self.ModalityNormalization = nn.LayerNorm(self.dModel) - self.EncoderPositionalEncoding = PositionalEncoding(dModel=self.dModel, maxLen=self.peMaxLen) #512,500 - - # audio - if not self.modal == "VO": - # front-end - wav2vecModel, cfg, task = load_model_ensemble_and_task([self.W2Vfile], arg_overrides={ - "apply_mask": True, - "mask_prob": 0.5, - "mask_channel_prob": 0.25, - "mask_channel_length": 64, - "layerdrop": 0.1, - "activation_dropout": 0.1, - "feature_grad_mult": 0.0, - }) - wav2vecModel = wav2vecModel[0] - wav2vecModel.remove_pretraining_modules() - self.wav2vecModel = wav2vecModel - # back-end - self.audioConv = conv1dLayers(self.maskedLayerNorm, self.audinSize, self.dModel, self.dModel, downsample=True) - audioEncoderLayer = nn.TransformerEncoderLayer(d_model=self.dModel, nhead=self.nHeads, dim_feedforward=self.fcHiddenSize, dropout=self.dropout) - self.audioEncoder = nn.TransformerEncoder(audioEncoderLayer, num_layers=self.numLayers, norm=tx_norm) - else: - self.wav2vecModel = None #主要是这三个 - self.audioConv = None - self.audioEncoder = None - - # visual - if not self.modal == "AO": - # front-end - visualModel = MoCoVisualFrontend(model_config) - if self.MoCofile is not None: - visualModel.load_state_dict(torch.load(self.MoCofile, map_location="cpu"), strict=False) - self.visualModel = visualModel - # back-end - self.videoConv = conv1dLayers(self.maskedLayerNorm, self.vidinSize, self.dModel, self.dModel) - videoEncoderLayer = nn.TransformerEncoderLayer(d_model=self.dModel, nhead=self.nHeads, dim_feedforward=self.fcHiddenSize, dropout=self.dropout) - self.videoEncoder = nn.TransformerEncoder(videoEncoderLayer, num_layers=self.numLayers, norm=tx_norm) - else: - self.visualModel = None #主要是这三个 - self.videoConv = None - self.videoEncoder = None - - # JointConv for fusion - if self.modal == "AV": - self.jointConv = conv1dLayers(self.maskedLayerNorm, 2 * self.dModel, self.dModel, self.dModel) - jointEncoderLayer = nn.TransformerEncoderLayer(d_model=self.dModel, nhead=self.nHeads, dim_feedforward=self.fcHiddenSize, dropout=self.dropout) - self.jointEncoder = nn.TransformerEncoder(jointEncoderLayer, num_layers=self.numLayers, norm=tx_norm) - - # self.jointOutputConv = outputConv(self.maskedLayerNorm, self.dModel, self.numClasses) - # self.decoderPositionalEncoding = PositionalEncoding(dModel=self.dModel, maxLen=self.peMaxLen) - # self.embed = torch.nn.Sequential( - # nn.Embedding(self.numClasses, self.dModel), - # self.decoderPositionalEncoding - # ) - # jointDecoderLayer = nn.TransformerDecoderLayer(d_model=self.dModel, nhead=self.nHeads, 
dim_feedforward=self.fcHiddenSize, dropout=self.dropout)
-        # self.jointAttentionDecoder = nn.TransformerDecoder(jointDecoderLayer, num_layers=self.numLayers, norm=tx_norm)
-        # self.jointAttentionOutputConv = outputConv("LN", self.dModel, self.numClasses)
-
-    def forward(self, inputBatch, maskw2v):
-        audioBatch, audMask, videoBatch, vidLen = inputBatch #torch.Size([2, 32480]),torch.Size([2, 32480]),torch.Size([2, 52, 1, 112, 112]),[52,47] # the tail of audMask is a run of True values (masked); everything else is False
-        if not self.modal == "VO":
-            try:
-                result = self.wav2vecModel.extract_features(audioBatch, padding_mask=audMask, mask=maskw2v) #new_version: this step divides the length by 320 and floors it
-                audioBatch,audMask =result["x"],result["padding_mask"] #torch.Size([2, 101, 1024]), torch.Size([2, 101]) # the shapes have changed, so the mask must be kept consistent with them
-                if audMask==None:
-                    audMask= torch.full( (audioBatch.shape[0], audioBatch.shape[1]), False, device=audioBatch.device ) #TODO
-
-                audLen = torch.sum(~audMask, dim=1) #tensor([101, 90], device='cuda:0')
-            except Exception as e:
-                print(e)
-                print(audioBatch.shape)
-                print(audMask)
-
-        else:
-            audLen = None
-
-        if not self.modal == "AO":
-            videoBatch = videoBatch.transpose(1, 2)
-            videoBatch = self.visualModel(videoBatch, vidLen.long()) #torch.Size([99, 2048])
-            videoBatch = list(torch.split(videoBatch, vidLen.tolist(), dim=0)) # split into a list: [(52,2048), (47,2048)]
-
-        #print(audioBatch.shape,audLen,videoBatch[0].shape,videoBatch[1].shape, videoBatch[2].shape,videoBatch[3].shape,vidLen)
-        audioBatch, videoBatch, inputLenBatch, mask = self.makePadding(audioBatch, audLen, videoBatch, vidLen) #[2, 160, 1024], torch.Size([2, 80, 2048]), tensor([80, 80], (2,80) # this step is the crucial one
-        #print( max(max(vidLen).item()*2, max(audLen).item()), audioBatch.shape, videoBatch.shape, inputLenBatch, mask.shape)
-        if isinstance(self.maskedLayerNorm, MaskedLayerNorm):
-            self.maskedLayerNorm.SetMaskandLength(mask, inputLenBatch)
-
-        if not self.modal == "VO":
-            audioBatch = audioBatch.transpose(1, 2) #?
- audioBatch = self.audioConv(audioBatch) #[2, 1024, 80] - audioBatch = audioBatch.transpose(1, 2).transpose(0, 1) - audioBatch = self.EncoderPositionalEncoding(audioBatch) - audioBatch = self.audioEncoder(audioBatch, src_key_padding_mask=mask) #[80,2,1024] - - if not self.modal == "AO": - videoBatch = videoBatch.transpose(1, 2) - videoBatch = self.videoConv(videoBatch) #[2, 1024, 80] - videoBatch = videoBatch.transpose(1, 2).transpose(0, 1) - videoBatch = self.EncoderPositionalEncoding(videoBatch) - videoBatch = self.videoEncoder(videoBatch, src_key_padding_mask=mask) #[80, 2, 1024] - - if self.modal == "AO": - jointBatch = audioBatch - elif self.modal == "VO": - jointBatch = videoBatch - else: - jointBatch = torch.cat([self.ModalityNormalization(audioBatch), self.ModalityNormalization(videoBatch)], dim=2) #torch.Size([80, 2, 2048]) - jointBatch = jointBatch.transpose(0, 1).transpose(1, 2) #(2,2048,80) - jointBatch = self.jointConv(jointBatch) #(2,1024,80) - jointBatch = jointBatch.transpose(1, 2).transpose(0, 1) - jointBatch = self.EncoderPositionalEncoding(jointBatch) - jointBatch = self.jointEncoder(jointBatch, src_key_padding_mask=mask) #[80, 2, 1024] - - jointBatch = jointBatch.transpose(0, 1) #(2,129,1024) #new - return jointBatch, inputLenBatch, mask #[80, 2, 1024], [80,80], [2,80] mask全是false - - - def makeMaskfromLength(self, maskShape, maskLength, maskDevice): - mask = torch.zeros(maskShape, device=maskDevice) - mask[(torch.arange(mask.shape[0]), maskLength - 1)] = 1 - mask = (1 - mask.flip([-1]).cumsum(-1).flip([-1])).bool() - return mask - - def makePadding(self, audioBatch, audLen, videoBatch, vidLen): - if self.modal == "AO": - audPadding = audLen % 2 - mask = (audPadding + audLen) > 2 * self.reqInpLen - audPadding = mask * audPadding + (~mask) * (2 * self.reqInpLen - audLen) - audLeftPadding = torch.floor(torch.div(audPadding, 2)).int() - audRightPadding = torch.ceil(torch.div(audPadding, 2)).int() - - audioBatch = audioBatch.unsqueeze(1).unsqueeze(1) - audioBatch = list(audioBatch) - for i, _ in enumerate(audioBatch): - pad = nn.ReplicationPad2d(padding=(0, 0, audLeftPadding[i], audRightPadding[i])) - audioBatch[i] = pad(audioBatch[i][:, :, :audLen[i]]).squeeze(0).squeeze(0) - - audioBatch = pad_sequence(audioBatch, batch_first=True) - inputLenBatch = ((audLen + audPadding) // 2).long() - mask = self.makeMaskfromLength([audioBatch.shape[0]] + [audioBatch.shape[1] // 2], inputLenBatch, audioBatch.device) - - elif self.modal == "VO": - vidPadding = torch.zeros(len(videoBatch)).long().to(vidLen.device) - - mask = (vidPadding + vidLen) > self.reqInpLen - vidPadding = mask * vidPadding + (~mask) * (self.reqInpLen - vidLen) - - vidLeftPadding = torch.floor(torch.div(vidPadding, 2)).int() - vidRightPadding = torch.ceil(torch.div(vidPadding, 2)).int() - - for i, _ in enumerate(videoBatch): - pad = nn.ReplicationPad2d(padding=(0, 0, vidLeftPadding[i], vidRightPadding[i])) - videoBatch[i] = pad(videoBatch[i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) - - videoBatch = pad_sequence(videoBatch, batch_first=True) - inputLenBatch = (vidLen + vidPadding).long() - mask = self.makeMaskfromLength(videoBatch.shape[:-1], inputLenBatch, videoBatch.device) - - else: - dismatch = audLen - 2 * vidLen #tensor([0, 1, 0, 2], device='cuda:0') - vidPadding = torch.ceil(torch.div(dismatch, 2)).int() #tensor([0.0000, 0.5000, 0.0000, 1.0000], device='cuda:0') tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32) - vidPadding = vidPadding * (vidPadding > 0) #tensor([0, 1, 0, 1], device='cuda:0', 
dtype=torch.int32)
-            audPadding = 2 * vidPadding - dismatch #tensor([0, 1, 0, 0], device='cuda:0')
-
-            mask = (vidPadding + vidLen) > self.reqInpLen #80 tensor([False, True, True, True], device='cuda:0')
-            vidPadding = mask * vidPadding + (~mask) * (self.reqInpLen - vidLen) #tensor([21, 1, 0, 1], device='cuda:0', dtype=torch.int32)
-            mask = (audPadding + audLen) > 2 * self.reqInpLen #tensor([False, True, True, True], device='cuda:0')
-            audPadding = mask * audPadding + (~mask) * (2 * self.reqInpLen - audLen) #tensor([42, 1, 0, 0], device='cuda:0')
-
-            vidLeftPadding = torch.floor(torch.div(vidPadding, 2)).int() #tensor([10, 0, 0, 0], device='cuda:0', dtype=torch.int32)
-            vidRightPadding = torch.ceil(torch.div(vidPadding, 2)).int() #tensor([11, 1, 0, 1], device='cuda:0', dtype=torch.int32)
-            audLeftPadding = torch.floor(torch.div(audPadding, 2)).int() #tensor([21, 0, 0, 0], device='cuda:0', dtype=torch.int32)
-            audRightPadding = torch.ceil(torch.div(audPadding, 2)).int() #tensor([21, 1, 0, 0], device='cuda:0', dtype=torch.int32)
-            # input audioBatch, torch.Size([4, 284, 1024])
-            audioBatch = audioBatch.unsqueeze(1).unsqueeze(1) #torch.Size([4, 1, 1, 284, 1024])
-            audioBatch = list(audioBatch) # a list of torch.Size([1, 1, 284, 1024]) tensors
-            for i, _ in enumerate(audioBatch):
-                pad = nn.ReplicationPad2d(padding=(0, 0, audLeftPadding[i], audRightPadding[i]))
-                audioBatch[i] = pad(audioBatch[i][:, :, :audLen[i]]).squeeze(0).squeeze(0) #audioBatch[i].shape, torch.Size([1, 1, 284, 1024])
-                # print(i,audioBatch[i].shape)
-                pad = nn.ReplicationPad2d(padding=(0, 0, vidLeftPadding[i], vidRightPadding[i]))
-                videoBatch[i] = pad(videoBatch[i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
-                # print(i,videoBatch[i].shape)
-
-            audioBatch = pad_sequence(audioBatch, batch_first=True)
-            videoBatch = pad_sequence(videoBatch, batch_first=True)
-            inputLenBatch = (vidLen + vidPadding).long()
-            mask = self.makeMaskfromLength(videoBatch.shape[:-1], inputLenBatch, videoBatch.device)
-
-        return audioBatch, videoBatch, inputLenBatch, mask
diff --git a/src/slam_llm/models/AV/avsr_model.py b/src/slam_llm/models/AV/avsr_model.py
deleted file mode 100644
index ff00eea5..00000000
--- a/src/slam_llm/models/AV/avsr_model.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import types
-import torch
-import soundfile as sf
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import List, Optional, Tuple, Union
-from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
-from transformers import (
-    LlamaForCausalLM,
-    LlamaTokenizer,
-    LlamaConfig,
-)
-import whisper
-
-from slam_llm.utils.config_utils import generate_peft_config
-from slam_llm.utils.train_utils import print_model_size
-
-from .AV.av_net import AVNet
-from .slam_model import setup_llm
-from torch.nn.utils.rnn import pad_sequence
-import copy
-from slam_llm.utils.metric import compute_accuracy
-
-def setupavsr_model(tokenizer, train_config, model_config, **kwargs):
-    return avsrllm_model(tokenizer, train_config, model_config, **kwargs)
-
-class avsrllm_model(nn.Module):
-    def __init__(
-        self,
-        tokenizer,
-        train_config,
-        model_config,
-        **kwargs
-    ):
-        super().__init__()
-
-        self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss
-
-        # audio-visual ↓
-        self.avnet=AVNet(model_config)
-
-        # load_ckpt ↑
-        checkpoint = torch.load(model_config.TRAIN_LRS3_MODEL_FILE)
-        self.avnet.load_state_dict(checkpoint['state_dict'],strict=False) # the modules for the final CTC/attention outputs are not used
-
-        # freeze (this is also handled by the outer training code)
-        for name, param in self.avnet.named_parameters():
-            param.requires_grad = False
-        self.avnet.eval()
-
-        # llama
-        self.llm = setup_llm(train_config, model_config, **kwargs)
-
-        # projector
-        self.feature_projector = nn.Linear(model_config.DMODEL, self.llm.config.hidden_size) #(512,4096) there seems to be a leftover issue here, TO DO
-
-        # tokenizer
-        self.tokenizer = tokenizer #tokenizer = LlamaTokenizer.from_pretrained(model_config.llm_path) no need to save it
-        self.metric = kwargs.get("metric", "acc")
-
-    def forward(self, inputBatch0,inputBatch1,inputBatch2,inputBatch3, targetoutBatch, targetLenBatch, maskw2v, **kwargs):
-        inputBatch=(inputBatch0, inputBatch1,inputBatch2,inputBatch3) # targetinBatch is the variant with the prefix prepended at the front
-
-        jointBatch, inputLenBatch, mask = self.avnet(inputBatch, maskw2v) #[129, 2, 1024], [129,125], [2,129] positions where mask is False are kept, masked positions are True; only the last 4 entries of mask[1] are True #output should be (bs, l, dim)
-        jointBatch = jointBatch.transpose(0, 1) #(2,129,1024)
-
-        # project
-        feature_tokens = self.feature_projector(jointBatch) #(2,129,4096)
-
-        if hasattr(self.llm.model, "embed_tokens"):
-            texts_embeds = self.llm.model.embed_tokens(targetoutBatch)
-        else: #
-            texts_embeds = self.llm.model.model.embed_tokens(targetoutBatch) #(2,37)-> (2,37,4096)
-
-        # restore the original lengths: take each item's features and text, concatenate them, then pad
-
-        #input_list=[torch.cat( (jointBatch[i, ~mask[i]] , targetoutBatch[i][:targetLenBatch[i]]), dim=1) for i in range(jointBatch.size(0) )]
-        # for i in range(jointBatch.size(0)):
-        #     a= feature_tokens[i, ~mask[i]] #(129,4096) (125,4096)
-        #     b= texts_embeds[i][:targetLenBatch[i]][:] #(37,4096) (26,4096)
-        #     input= torch.cat( (a,b), dim=0) #(166,4096) (151,4096)
-
-        input_lists=[torch.cat( (feature_tokens[i, ~mask[i]], texts_embeds[i][:targetLenBatch[i]][:] ) , dim=0 ) for i in range(jointBatch.size(0)) ]
-        inputs_embeds = pad_sequence(input_lists, batch_first=True, padding_value=0) #(2,166,4096)
-
-        lengths=[item.size(0) for item in input_lists] #[166, 151]
-        max_length=max(lengths) #166
-        mask2 = torch.zeros(len(input_lists),max_length,dtype=torch.bool) #(2,166)
-        for i,length in enumerate(lengths):
-            mask2[i,:length]=1 # masked positions are False, the rest are True; only the last 15 entries of mask2[1] are False
-        mask2=mask2.to("cuda:0")
-
-
-        # labels_list=[]
-        # for i in range(jointBatch.size(0)):
-        #     labels= torch.cat(( torch.full((inputLenBatch[i],),self.IGNORE_INDEX , device=targetoutBatch.device) , targetoutBatch[i][:targetLenBatch[i]]) ,dim=0)
-        #     labels_list.append((labels))
-        labels_list= [ torch.cat(( torch.full((inputLenBatch[i],),self.IGNORE_INDEX , device=targetoutBatch.device) , targetoutBatch[i][:targetLenBatch[i]]) ,dim=0) for i in range(jointBatch.size(0)) ] #[166,151]
-        labels = pad_sequence(labels_list, batch_first=True, padding_value=self.IGNORE_INDEX) #(2,166)
-
-
-        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask = mask2, labels=labels) # PeftModelForCausalLM implements the label shift internally
-
-        acc = -1
-        if self.metric:
-            with torch.no_grad():
-                preds = torch.argmax(model_outputs.logits, -1)
-                acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)
-
-        return model_outputs, acc #logits:[2,292,32000] #loss:6.9475
-
-    def save_pretrained(self, output_dir):
-        save_dir= output_dir+'/avsrmodel.pt'
-        self.llm.save_pretrained(output_dir)
-        modules_to_save={
-            'avnet': self.avnet.state_dict(),
-            'feature_projector':self.feature_projector.state_dict(),
-        }
-
-        torch.save(modules_to_save,save_dir)
-
-
diff --git a/src/slam_llm/models/AV/moco_visual_frontend.py b/src/slam_llm/models/AV/moco_visual_frontend.py
deleted file mode 100644
index d53e786a..00000000
--- a/src/slam_llm/models/AV/moco_visual_frontend.py
+++ /dev/null
@@ -1,42 
+0,0 @@ -import torch -import torch.nn as nn -import torchvision.models as models - - -class MoCoVisualFrontend(nn.Module): - # def __init__(self, dModel=args["FRONTEND_DMODEL"], nClasses=args["WORD_NUM_CLASSES"], frameLen=args["FRAME_LENGTH"], - # vidfeaturedim=args["VIDEO_FEATURE_SIZE"]): - def __init__(self, model_config): - - super(MoCoVisualFrontend, self).__init__() - self.dModel = model_config.FRONTEND_DMODEL - self.nClasses = model_config.WORD_NUM_CLASSES - self.frameLen = model_config.FRAME_LENGTH - self.vidfeaturedim = model_config.VIDEO_FEATURE_SIZE - - - # Conv3D - self.frontend3D = nn.Sequential( - nn.Conv3d(1, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False), - nn.BatchNorm3d(64), - nn.ReLU(True), - nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) - ) - # moco - MoCoModel = models.__dict__['resnet50']() #就当搞了个ResNet - MoCoModel.fc = nn.Identity() - MoCoModel.conv1 = nn.Identity() - MoCoModel.bn1 = nn.Identity() - MoCoModel.relu = nn.Identity() - MoCoModel.maxpool = nn.Identity() - self.MoCoModel = MoCoModel - - def forward(self, x, x_len): # x: 8,1,149,112,112 - x = self.frontend3D(x) #[2, 64, 52, 28, 28] - x = x.transpose(1, 2) - mask = torch.zeros(x.shape[:2], device=x.device) #(8,149) - mask[(torch.arange(mask.shape[0], device=x.device), x_len - 1)] = 1 - mask = (1 - mask.flip([-1]).cumsum(-1).flip([-1])).bool() #一堆true false - x = x[~mask] - x = self.MoCoModel(x) # torch.Size([99, 2048]) - return x diff --git a/src/slam_llm/models/AV/utils.py b/src/slam_llm/models/AV/utils.py deleted file mode 100644 index 74969681..00000000 --- a/src/slam_llm/models/AV/utils.py +++ /dev/null @@ -1,140 +0,0 @@ -import math - -import torch -import torch.nn as nn - - -class PositionalEncoding(nn.Module): - """ - A layer to add positional encodings to the inputs of a Transformer model. 
-    Formula:
-    PE(pos,2i) = sin(pos/10000^(2i/d_model))
-    PE(pos,2i+1) = cos(pos/10000^(2i/d_model))
-    """
-
-    def __init__(self, dModel, maxLen):
-        super(PositionalEncoding, self).__init__()
-        pe = torch.zeros(maxLen, dModel) #(500,512)
-        position = torch.arange(0, maxLen, dtype=torch.float).unsqueeze(dim=-1) #(500,1)
-        denominator = torch.exp(torch.arange(0, dModel, 2).float() * (math.log(10000.0) / dModel)) #(256,)
-        pe[:, 0::2] = torch.sin(position / denominator)
-        pe[:, 1::2] = torch.cos(position / denominator)
-        pe = pe.unsqueeze(dim=0).transpose(0, 1) #(500,1,512)
-        self.register_buffer("pe", pe)
-
-    def forward(self, inputBatch): #(152,8,512); when called from the decoder the input is (92,8,512)
-        outputBatch = inputBatch + self.pe[:inputBatch.shape[0], :, :] #(152,8,512) # the inputBatch that raised the error was [622,8,512] #self.pe [500,1,512]
-        return outputBatch
-
-
-class conv1dLayers(nn.Module):
-    def __init__(self, MaskedNormLayer, inD, dModel, outD, downsample=False):
-        super(conv1dLayers, self).__init__() #inD=1024, dModel=512, outD=512
-        if downsample:
-            kernel_stride = 2
-        else:
-            kernel_stride = 1
-        self.conv = nn.Sequential(
-            nn.Conv1d(inD, dModel, kernel_size=(kernel_stride,), stride=(kernel_stride,), padding=(0,)),
-            TransposeLayer(1, 2),
-            MaskedNormLayer,
-            TransposeLayer(1, 2),
-            nn.ReLU(True),
-            nn.Conv1d(dModel, outD, kernel_size=(1,), stride=(1,), padding=(0,))
-        )
-
-    def forward(self, inputBatch):
-        return self.conv(inputBatch)
-
-
-class outputConv(nn.Module): # this acts as the decoder; the final output dim is numClasses
-    def __init__(self, MaskedNormLayer, dModel, numClasses):
-        super(outputConv, self).__init__()
-        if MaskedNormLayer == "LN": # the only difference between the branches is the norm layer: a standard LayerNorm here
-            self.outputconv = nn.Sequential(
-                nn.Conv1d(dModel, dModel, kernel_size=(1,), stride=(1,), padding=(0,)),
-                TransposeLayer(1, 2),
-                nn.LayerNorm(dModel),
-                TransposeLayer(1, 2),
-                nn.ReLU(True),
-                nn.Conv1d(dModel, dModel // 2, kernel_size=(1,), stride=(1,), padding=(0,)),
-                TransposeLayer(1, 2),
-                nn.LayerNorm(dModel // 2),
-                TransposeLayer(1, 2),
-                nn.ReLU(True),
-                nn.Conv1d(dModel // 2, dModel // 2, kernel_size=(1,), stride=(1,), padding=(0,)),
-                TransposeLayer(1, 2),
-                nn.LayerNorm(dModel // 2),
-                TransposeLayer(1, 2),
-                nn.ReLU(True),
-                nn.Conv1d(dModel // 2, numClasses, kernel_size=(1,), stride=(1,), padding=(0,))
-            )
-        else:
-            self.outputconv = nn.Sequential( # MaskedNormLayer
-                nn.Conv1d(dModel, dModel, kernel_size=(1,), stride=(1,), padding=(0,)),
-                TransposeLayer(1, 2),
-                MaskedNormLayer,
-                TransposeLayer(1, 2),
-                nn.ReLU(True),
-                nn.Conv1d(dModel, dModel // 2, kernel_size=(1,), stride=(1,), padding=(0,)),
-                TransposeLayer(1, 2),
-                MaskedNormLayer,
-                TransposeLayer(1, 2),
-                nn.ReLU(True),
-                nn.Conv1d(dModel // 2, dModel // 2, kernel_size=(1,), stride=(1,), padding=(0,)),
-                TransposeLayer(1, 2),
-                MaskedNormLayer,
-                TransposeLayer(1, 2),
-                nn.ReLU(True),
-                nn.Conv1d(dModel // 2, numClasses, kernel_size=(1,), stride=(1,), padding=(0,))
-            )
-
-    def forward(self, inputBatch):
-        return self.outputconv(inputBatch)
-
-
-class MaskedLayerNorm(nn.Module):
-    def __init__(self, eps=1e-5):
-        super(MaskedLayerNorm, self).__init__()
-        self.register_buffer('mask', None, persistent=False)
-        self.register_buffer('inputLenBatch', None, persistent=False)
-        self.eps = eps
-
-    def SetMaskandLength(self, mask, inputLenBatch):
-        self.mask = mask
-        self.inputLenBatch = inputLenBatch
-
-    def expand2shape(self, inputBatch, expandedShape):
-        return inputBatch.unsqueeze(-1).unsqueeze(-1).expand(expandedShape)
-
-    def forward(self, inputBatch):
-        dModel = 
inputBatch.shape[-1] - maskBatch = ~self.mask.unsqueeze(-1).expand(inputBatch.shape) - - meanBatch = (inputBatch * maskBatch).sum((1, 2)) / (self.inputLenBatch * dModel) - stdBatch = ((inputBatch - self.expand2shape(meanBatch, inputBatch.shape)) ** 2 * maskBatch).sum((1, 2)) - stdBatch = stdBatch / (self.inputLenBatch * dModel) - - # Norm the input - normed = (inputBatch - self.expand2shape(meanBatch, inputBatch.shape)) / \ - (torch.sqrt(self.expand2shape(stdBatch + self.eps, inputBatch.shape))) - return normed - - -class TransposeLayer(nn.Module): - def __init__(self, dim1, dim2): - super(TransposeLayer, self).__init__() - self.dim1 = dim1 - self.dim2 = dim2 - - def forward(self, inputBatch): - return inputBatch.transpose(self.dim1, self.dim2) - - -def generate_square_subsequent_mask(sz: int, device): - r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). - Unmasked positions are filled with float(0.0). - """ # 三角矩阵 为了infer的时候的decode - mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1) - mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) - return mask diff --git a/src/slam_llm/models/AV/visual_encoder.py b/src/slam_llm/models/AV/visual_encoder.py deleted file mode 100644 index a1d3412d..00000000 --- a/src/slam_llm/models/AV/visual_encoder.py +++ /dev/null @@ -1,90 +0,0 @@ -import torch -import torch.nn as nn -import torchvision.models as models -from config import args - - -class VisualEncoder(nn.Module): - # def __init__(self, dModel=args["FRONTEND_DMODEL"], nClasses=args["WORD_NUM_CLASSES"], frameLen=args["FRAME_LENGTH"], - # vidfeaturedim=args["VIDEO_FEATURE_SIZE"]): - def __init__(self, model_config): - - super(VisualEncoder, self).__init__() - self.dModel = model_config.FRONTEND_DMODEL - self.nClasses = model_config.WORD_NUM_CLASSES - self.frameLen = model_config.FRAME_LENGTH - self.vidfeaturedim = model_config.VIDEO_FEATURE_SIZE - - - # Conv3D - self.frontend3D = nn.Sequential( - nn.Conv3d(1, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False), - nn.BatchNorm3d(64), - nn.ReLU(True), - nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) - ) - # moco - MoCoModel = models.__dict__['resnet50']() #就当搞了个ResNet - MoCoModel.fc = nn.Identity() - MoCoModel.conv1 = nn.Identity() - MoCoModel.bn1 = nn.Identity() - MoCoModel.relu = nn.Identity() - MoCoModel.maxpool = nn.Identity() #有点意思 - self.MoCoModel = MoCoModel - - self.MoCoModel.load_state_dict(torch.load(MoCofile, map_location="cpu"), strict=False) - - - # AV - self.peMaxLen = model_config.PE_MAX_LENGTH - tx_norm = nn.LayerNorm(dModel) - self.maskedLayerNorm = MaskedLayerNorm() - self.EncoderPositionalEncoding = PositionalEncoding(dModel=self.dModel, maxLen= self.peMaxLen) #512,500 - - # visual backend - self.nHeads = model_config.X_ATTENTION_HEADS - self.fcHiddenSize = model_config.TX_FEEDFORWARD_DIM - self.dropout = model_config.TX_DROPOUT - self.num_layers = model_config.TX_NUM_LAYERS - - self.videoConv = conv1dLayers(self.maskedLayerNorm, self.vidfeaturedim, self.dModel, self.dModel) - videoEncoderLayer = nn.TransformerEncoderLayer(d_model=self.dModel, nhead=self.nHeads, dim_feedforward=self.fcHiddenSize, dropout=self.dropout) - self.videoEncoder = nn.TransformerEncoder(videoEncoderLayer, num_layers=self.num_layers, norm=tx_norm) - - def forward(self, x, x_len): # x: 8,1,149,112,112 - x = self.frontend3D(x) #(8,64,149,28,28) - x = x.transpose(1, 2) #(8,149,64,28,28) - mask = 
torch.zeros(x.shape[:2], device=x.device) #(8,149) - mask[(torch.arange(mask.shape[0], device=x.device), x_len - 1)] = 1 - mask = (1 - mask.flip([-1]).cumsum(-1).flip([-1])).bool() #一堆true false - x = x[~mask] #(739,64,28,28) - x = self.MoCoModel(x) #(739,2048) - return x - - -class MaskedLayerNorm(nn.Module): - def __init__(self, eps=1e-5): - super(MaskedLayerNorm, self).__init__() - self.register_buffer('mask', None, persistent=False) - self.register_buffer('inputLenBatch', None, persistent=False) - self.eps = eps - - def SetMaskandLength(self, mask, inputLenBatch): - self.mask = mask - self.inputLenBatch = inputLenBatch - - def expand2shape(self, inputBatch, expandedShape): - return inputBatch.unsqueeze(-1).unsqueeze(-1).expand(expandedShape) - - def forward(self, inputBatch): - dModel = inputBatch.shape[-1] - maskBatch = ~self.mask.unsqueeze(-1).expand(inputBatch.shape) - - meanBatch = (inputBatch * maskBatch).sum((1, 2)) / (self.inputLenBatch * dModel) - stdBatch = ((inputBatch - self.expand2shape(meanBatch, inputBatch.shape)) ** 2 * maskBatch).sum((1, 2)) - stdBatch = stdBatch / (self.inputLenBatch * dModel) - - # Norm the input - normed = (inputBatch - self.expand2shape(meanBatch, inputBatch.shape)) / \ - (torch.sqrt(self.expand2shape(stdBatch + self.eps, inputBatch.shape))) - return normed diff --git a/src/slam_llm/pipeline/finetune.py b/src/slam_llm/pipeline/finetune.py index 1cd1dd95..f9132263 100644 --- a/src/slam_llm/pipeline/finetune.py +++ b/src/slam_llm/pipeline/finetune.py @@ -19,18 +19,13 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload from slam_llm.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing -# config -# from llama_recipes.configs import fsdp_config as FSDP_CONFIG -# from llama_recipes.configs import train_config as TRAIN_CONFIG -# from llama_recipes.configs import model_config as MODEL_CONFIG -# from llama_recipes.configs import log_config as LOG_CONFIG -from slam_llm.data.concatenator import ConcatDataset - # util from slam_llm.utils import fsdp_auto_wrap_policy from slam_llm.utils.config_utils import get_dataloader_kwargs -from slam_llm.utils.dataset_utils import get_preprocessed_dataset, load_module_from_py_file +from slam_llm.utils.dataset_utils import get_preprocessed_dataset +from slam_llm.data.concatenator import ConcatDataset + from slam_llm.utils.model_utils import get_custom_model_factory from slam_llm.utils.train_utils import ( train, @@ -49,7 +44,7 @@ from omegaconf import DictConfig, ListConfig, OmegaConf from pathlib import Path -@hydra.main(config_name=None) +@hydra.main(config_name=None, version_base=None) def main_hydra(cfg: DictConfig): def to_plain_list(cfg_item): if isinstance(cfg_item, ListConfig): @@ -82,21 +77,15 @@ def main(kwargs: DictConfig): kwargs.model_config, \ kwargs.log_config, \ kwargs.dataset_config + fsdp_config.use_fp16 = train_config.use_fp16 - if model_config.encoder_name=="av_hubert": - OmegaConf.set_struct(kwargs,False) - del kwargs["train_config"] - del kwargs["fsdp_config"] - del kwargs["model_config"] - del kwargs["log_config"] - del kwargs["dataset_config"] - OmegaConf.set_struct(kwargs,True) - else: - del kwargs.train_config - del kwargs.fsdp_config - del kwargs.model_config - del kwargs.log_config - del kwargs.dataset_config + OmegaConf.set_struct(kwargs,False) + del kwargs["train_config"] + del kwargs["fsdp_config"] + del kwargs["model_config"] + del kwargs["log_config"] + del kwargs["dataset_config"] + OmegaConf.set_struct(kwargs,True) # Set log if not 
os.path.exists(os.path.dirname(log_config.log_file)): diff --git a/src/slam_llm/pipeline/inference_batch.py b/src/slam_llm/pipeline/inference_batch.py index ba499d3c..1dc03940 100644 --- a/src/slam_llm/pipeline/inference_batch.py +++ b/src/slam_llm/pipeline/inference_batch.py @@ -20,7 +20,7 @@ from omegaconf import DictConfig, ListConfig, OmegaConf -@hydra.main(config_name=None) +@hydra.main(config_name=None, version_base=None) def main_hydra(cfg: DictConfig): def to_plain_list(cfg_item): if isinstance(cfg_item, ListConfig): @@ -53,20 +53,14 @@ def main(kwargs: DictConfig): kwargs.model_config, \ kwargs.log_config, \ kwargs.dataset_config - if model_config.encoder_name=="av_hubert": - OmegaConf.set_struct(kwargs,False) - del kwargs["train_config"] - del kwargs["fsdp_config"] - del kwargs["model_config"] - del kwargs["log_config"] - del kwargs["dataset_config"] - OmegaConf.set_struct(kwargs,True) - else: - del kwargs.train_config - del kwargs.fsdp_config - del kwargs.model_config - del kwargs.log_config - del kwargs.dataset_config + + OmegaConf.set_struct(kwargs,False) + del kwargs["train_config"] + del kwargs["fsdp_config"] + del kwargs["model_config"] + del kwargs["log_config"] + del kwargs["dataset_config"] + OmegaConf.set_struct(kwargs,True) # Set log if not os.path.exists(os.path.dirname(log_config.log_file)): diff --git a/src/slam_llm/utils/dataset_utils.py b/src/slam_llm/utils/dataset_utils.py index 87dac75e..a43a603f 100644 --- a/src/slam_llm/utils/dataset_utils.py +++ b/src/slam_llm/utils/dataset_utils.py @@ -34,7 +34,7 @@ def get_custom_dataset(dataset_config, tokenizer, split: str): if not module_path.endswith(".py"): raise ValueError(f"Dataset file {module_path} is not a .py file.") - module_path = Path("/root/SLAM-LLM/"+module_path) #TODO + module_path = Path(module_path) if not module_path.is_file(): raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.") diff --git a/src/slam_llm/utils/model_utils.py b/src/slam_llm/utils/model_utils.py index 7a6b1967..8e620890 100644 --- a/src/slam_llm/utils/model_utils.py +++ b/src/slam_llm/utils/model_utils.py @@ -17,7 +17,7 @@ def get_custom_model_factory(model_config, logger): if not module_path.endswith(".py"): raise ValueError(f"Dataset file {module_path} is not a .py file.") - module_path = Path("/root/SLAM-LLM/"+module_path) #TODO + module_path = Path(module_path) if not module_path.is_file(): raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")