From 307d22ec7513e870ad748f86ea13d762c1e6ef45 Mon Sep 17 00:00:00 2001 From: yanghaha0908 Date: Mon, 15 Jan 2024 12:23:19 +0800 Subject: [PATCH] debug --- src/llama_recipes/models/AV/av_net.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/llama_recipes/models/AV/av_net.py b/src/llama_recipes/models/AV/av_net.py index 5c4291da..8a4fa658 100644 --- a/src/llama_recipes/models/AV/av_net.py +++ b/src/llama_recipes/models/AV/av_net.py @@ -97,12 +97,18 @@ def __init__(self, model_config): def forward(self, inputBatch, maskw2v): audioBatch, audMask, videoBatch, vidLen = inputBatch #torch.Size([2, 32480]),torch.Size([2, 32480]),torch.Size([2, 52, 1, 112, 112]),[52,47] # audMask尾部有一堆true表示mask,其余都是false if not self.modal == "VO": - result = self.wav2vecModel.extract_features(audioBatch, padding_mask=audMask, mask=maskw2v) #new_version 这一步/320 并向下取整 - audioBatch,audMask =result["x"],result["padding_mask"] #torch.Size([2, 101, 1024]), torch.Size([2, 101]) #形状变了 所以还得跟形状保持一致 - if audMask==None: - audMask= torch.full( (audioBatch.shape[0], audioBatch.shape[1]), False, device=audioBatch.device ) #TODO + try: + result = self.wav2vecModel.extract_features(audioBatch, padding_mask=audMask, mask=maskw2v) #new_version 这一步/320 并向下取整 + audioBatch,audMask =result["x"],result["padding_mask"] #torch.Size([2, 101, 1024]), torch.Size([2, 101]) #形状变了 所以还得跟形状保持一致 + if audMask==None: + audMask= torch.full( (audioBatch.shape[0], audioBatch.shape[1]), False, device=audioBatch.device ) #TODO + + audLen = torch.sum(~audMask, dim=1) #tensor([101, 90], device='cuda:0') + except Exception as e: + print(e) + print(audioBatch.shape) + print(audMask) - audLen = torch.sum(~audMask, dim=1) #tensor([101, 90], device='cuda:0') else: audLen = None