
merge avsr
yanghaha0908 committed Jan 15, 2024
1 parent 307d22e commit 79abbb8
Showing 5 changed files with 14 additions and 17 deletions.
File renamed without changes.
8 changes: 4 additions & 4 deletions src/llama_recipes/models/projector.py
@@ -13,11 +13,11 @@ def __init__(self, config):
         self.linear2 = nn.Linear(2048, config.llm_dim)
 
     def forward(self, x):
-        batch_size, seq_len, dim = x.size() #2,151,512
-        num_frames_to_discard = seq_len % self.k #1
+        batch_size, seq_len, dim = x.size()
+        num_frames_to_discard = seq_len % self.k
         if num_frames_to_discard > 0:
-            x = x[:, :-num_frames_to_discard, :]
-        seq_len = x.size(1) #150
+            x = x[:, :-num_frames_to_discard, :]
+        seq_len = x.size(1)
 
         x = x.contiguous()
         x = x.view(batch_size, seq_len // self.k, dim * self.k)
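The deleted inline comments recorded example shapes (a batch of 2, 151 frames of dim 512, trimmed to 150 frames before stacking). For context, below is a minimal, self-contained sketch of a frame-concatenation projector of this kind; the class name, constructor arguments, and the first linear layer are assumptions (only `linear2` and the `forward` logic appear in the diff above).

```python
import torch
import torch.nn as nn

class ConcatProjector(nn.Module):
    """Sketch of a frame-concatenation projector (class/argument names assumed)."""

    def __init__(self, encoder_dim: int, llm_dim: int, k: int):
        super().__init__()
        self.k = k
        self.linear1 = nn.Linear(encoder_dim * k, 2048)  # assumed first layer
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(2048, llm_dim)          # as shown in the diff context

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = x.size()
        # Drop trailing frames so seq_len becomes divisible by k.
        num_frames_to_discard = seq_len % self.k
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)
        # Stack k consecutive frames along the feature dimension.
        x = x.contiguous().view(batch_size, seq_len // self.k, dim * self.k)
        return self.linear2(self.relu(self.linear1(x)))


if __name__ == "__main__":
    proj = ConcatProjector(encoder_dim=512, llm_dim=5120, k=5)  # k=5 is illustrative
    out = proj(torch.randn(2, 151, 512))  # 151 -> 150 frames -> 30 stacked steps
    print(out.shape)                      # torch.Size([2, 30, 5120])
```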
10 changes: 5 additions & 5 deletions src/llama_recipes/models/slam_model.py
@@ -207,14 +207,14 @@ def forward(self,
             encoder_outs , inputLenBatch, audio_mel_post_mask = self.encoder((audio, audiomask, visual, vis_len) ,maskw2v) # bs*seq*dim
 
         if self.model_config.encoder_projector == "q-former":
-            encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
+            encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
         if self.model_config.encoder_projector == "linear":
-            encoder_outs = self.encoder_projector(encoder_outs) #torch.Size([2, 16, 5120])
+            encoder_outs = self.encoder_projector(encoder_outs)
 
         if input_ids is not None:
             input_ids[input_ids == -1] = 0
             if hasattr(self.llm.model, "embed_tokens"):
-                inputs_embeds = self.llm.model.embed_tokens(input_ids) #torch.Size([2, 74, 4096])
+                inputs_embeds = self.llm.model.embed_tokens(input_ids)
             elif hasattr(self.llm.model.model, "embed_tokens"):
                 inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
             else:
@@ -223,8 +223,8 @@ def forward(self,
         if audio_mask is not None:
             batch_size, token_num, dims = inputs_embeds.shape
             _, l, _ = encoder_outs.shape
-            encoder_outs_pad = F.pad(encoder_outs, (0, 0, 0, token_num-l, 0, 0), value=0.0) #torch.Size([2, 74, 5120]) #len上padding
-            inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None]) #tensor(16, device='cuda:0')
+            encoder_outs_pad = F.pad(encoder_outs, (0, 0, 0, token_num-l, 0, 0), value=0.0)
+            inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
 
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
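The lines changed here only drop debugging comments; the logic that splices the projected encoder features into the LLM's input embeddings is unchanged. A minimal sketch of that masked merge follows, assuming `audio_mask` is a boolean mask over the token positions reserved for audio/visual features and that encoder and embedding dims already match; the helper name and concrete shapes are illustrative.

```python
import torch
import torch.nn.functional as F

def merge_encoder_into_embeds(encoder_outs: torch.Tensor,
                              inputs_embeds: torch.Tensor,
                              audio_mask: torch.Tensor) -> torch.Tensor:
    """encoder_outs: (B, L, D); inputs_embeds: (B, T, D) with T >= L;
    audio_mask: (B, T) bool, True where encoder features should be placed."""
    _, token_num, _ = inputs_embeds.shape
    _, l, _ = encoder_outs.shape
    # Right-pad the encoder output along the sequence axis up to token_num.
    encoder_outs_pad = F.pad(encoder_outs, (0, 0, 0, token_num - l, 0, 0), value=0.0)
    # Encoder features at masked positions, original token embeddings elsewhere.
    return encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])


if __name__ == "__main__":
    enc = torch.randn(2, 16, 4096)       # projected encoder features
    emb = torch.randn(2, 74, 4096)       # LLM token embeddings
    mask = torch.zeros(2, 74, dtype=torch.bool)
    mask[:, :16] = True                  # audio slots come first (illustrative layout)
    print(merge_encoder_into_embeds(enc, emb, mask).shape)  # torch.Size([2, 74, 4096])
```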
8 changes: 4 additions & 4 deletions src/llama_recipes/pipeline/finetune.py
@@ -53,7 +53,7 @@ def main(**kwargs):
     update_config((train_config, fsdp_config, model_config, log_config), **kwargs)
 
     # Set log
-    if not os.path.exists(os.path.dirname(log_config.log_file)): #x
+    if not os.path.exists(os.path.dirname(log_config.log_file)):
         os.makedirs(os.path.dirname(log_config.log_file), exist_ok=True)
     logging.basicConfig(
         level=logging.INFO,
@@ -86,21 +86,21 @@ def main(**kwargs):
     torch.manual_seed(train_config.seed)
     random.seed(train_config.seed)
 
-    if train_config.enable_fsdp: #x
+    if train_config.enable_fsdp:
         setup()
         # torchrun specific
         local_rank = int(os.environ["LOCAL_RANK"])
         rank = int(os.environ["RANK"])
         world_size = int(os.environ["WORLD_SIZE"])
         logger.info(f"local_rank: {local_rank}, rank: {rank}, world_size: {world_size}")
 
-    if torch.distributed.is_initialized(): #x
+    if torch.distributed.is_initialized():
         torch.cuda.set_device(local_rank)
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
     # Set wandb
-    if not train_config.enable_fsdp or rank == 0: #x
+    if not train_config.enable_fsdp or rank == 0:
         if log_config.use_wandb:
             if not os.path.exists(log_config.wandb_dir):
                 os.makedirs(log_config.wandb_dir, exist_ok=True)
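These hunks only strip `#x` marker comments; the FSDP path they touch follows the usual torchrun convention of reading `LOCAL_RANK`, `RANK`, and `WORLD_SIZE` from the environment. A standalone sketch of that pattern is below; the function name and NCCL backend choice are assumptions, not taken from the repo's `setup()`.

```python
import os
import torch
import torch.distributed as dist

def setup_distributed():
    """Sketch of torchrun-style setup; name and backend choice are assumptions."""
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    if dist.is_initialized():
        torch.cuda.set_device(local_rank)  # pin this process to its GPU
        torch.cuda.empty_cache()           # clear any cached allocations
    return local_rank, rank, world_size

# Typically launched with something like:
#   torchrun --nproc_per_node=2 src/llama_recipes/pipeline/finetune.py ...
```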
5 changes: 1 addition & 4 deletions src/llama_recipes/pipeline/model_factory.py
@@ -9,10 +9,7 @@
 def model_factory(train_config, model_config, **kwargs):
 
     tokenizer = setup_tokenizer(train_config, model_config, **kwargs)
-    # if train_config.model_name=="avsr":
-    #     from llama_recipes.models.avsr_model import setupavsr_model
-    #     model = setupavsr_model(tokenizer, train_config, model_config, **kwargs)
-    # else:
+
     model = setup_model(tokenizer, train_config, model_config, **kwargs)
 
     ckpt_path = kwargs.get("ckpt_path", None) #FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
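With the commented-out AVSR branch removed, every model type goes through `setup_model`; consistent with the commit title, the AVSR-specific pieces are presumably merged into that common path rather than living in a separate `setupavsr_model`. A hypothetical call site, based only on the signature and the `ckpt_path` kwarg shown above (the return value handling is assumed):

```python
# Hypothetical invocation; only model_factory's signature and the ckpt_path
# kwarg (read via kwargs.get above) are taken from the diff.
result = model_factory(train_config, model_config, ckpt_path="path/to/projector_ckpt.pt")
```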
