Skip to content

Commit

Permalink
debug
Browse files (browse the repository at this point in the history)
  • Loading branch information
yanghaha0908 committed Jan 15, 2024
1 parent d2ea3d9 commit 307d22e
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions src/llama_recipes/models/AV/av_net.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -97,12 +97,18 @@ def __init__(self, model_config):
def forward(self, inputBatch, maskw2v):
audioBatch, audMask, videoBatch, vidLen = inputBatch #torch.Size([2, 32480]),torch.Size([2, 32480]),torch.Size([2, 52, 1, 112, 112]),[52,47] # audMask尾部有一堆true表示mask,其余都是false
if not self.modal == "VO":
result = self.wav2vecModel.extract_features(audioBatch, padding_mask=audMask, mask=maskw2v) #new_version 这一步/320 并向下取整
audioBatch,audMask =result["x"],result["padding_mask"] #torch.Size([2, 101, 1024]), torch.Size([2, 101]) #形状变了 所以还得跟形状保持一致
if audMask==None:
audMask= torch.full( (audioBatch.shape[0], audioBatch.shape[1]), False, device=audioBatch.device ) #TODO
try:
result = self.wav2vecModel.extract_features(audioBatch, padding_mask=audMask, mask=maskw2v) #new_version 这一步/320 并向下取整
audioBatch,audMask =result["x"],result["padding_mask"] #torch.Size([2, 101, 1024]), torch.Size([2, 101]) #形状变了 所以还得跟形状保持一致
if audMask==None:
audMask= torch.full( (audioBatch.shape[0], audioBatch.shape[1]), False, device=audioBatch.device ) #TODO

audLen = torch.sum(~audMask, dim=1) #tensor([101, 90], device='cuda:0')
except Exception as e:
print(e)
print(audioBatch.shape)
print(audMask)

audLen = torch.sum(~audMask, dim=1) #tensor([101, 90], device='cuda:0')
else:
audLen = None

Expand Down

0 comments on commit 307d22e

Please sign in to comment.