diff --git a/src/llama_recipes/datasets/speech_dataset.py b/src/llama_recipes/datasets/speech_dataset.py
index a9a257e0..18a5dd05 100644
--- a/src/llama_recipes/datasets/speech_dataset.py
+++ b/src/llama_recipes/datasets/speech_dataset.py
@@ -68,8 +68,8 @@ def __len__(self):
     def __getitem__(self, index):
         data_dict = self.data_list[index]
         audio_path = data_dict.get("source")
-        target = data_dict.get("target", None) #'ONE AH LITTLE RECKS THE LABORER HOW NEAR HIS WORK IS HOLDING HIM TO GOD THE LOVING LABORER THROUGH SPACE AND TIME AFTER ALL NOT TO CREATE ONLY OR FOUND ONLY'
-        task = data_dict.get("prompt", "ASR") #''
+        target = data_dict.get("target", None)
+        task = data_dict.get("prompt", "ASR")
         audio_raw = whisper.load_audio(audio_path)
         audio_raw = whisper.pad_or_trim(audio_raw)
 
@@ -151,7 +151,7 @@ def collator(self, samples):
 
         audio_mask = torch.zeros_like(attention_mask)
         for line, sample in enumerate(samples):
-            audio_mask[line, :sample['audio_length']] = 1 #downsample 再/5
+            audio_mask[line, :sample['audio_length']] = 1
 
         return {
             'input_ids': input_ids,
diff --git a/src/llama_recipes/models/projector.py b/src/llama_recipes/models/projector.py
index 66892b6e..0b066943 100644
--- a/src/llama_recipes/models/projector.py
+++ b/src/llama_recipes/models/projector.py
@@ -20,7 +20,7 @@ def forward(self, x):
 
         seq_len = x.size(1)
         x = x.contiguous()
-        x = x.view(batch_size, seq_len // self.k, dim * self.k)
+        x = x.view(batch_size, seq_len // self.k, dim * self.k)
         x = self.linear1(x)
         x = self.relu(x)
         x = self.linear2(x)
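
Note on the collator hunk: audio_mask marks which positions of the padded batch rows are covered by audio embeddings, so they can later be spliced into the LLM input sequence. Below is a minimal sketch of that logic; build_audio_mask is a hypothetical free-function stand-in for the loop inside collator, and the left-aligned audio prefix is an assumption read off the slice [:sample['audio_length']].

import torch

def build_audio_mask(attention_mask, audio_lengths):
    # One row per sample: 1 over the prefix that will hold audio embeddings,
    # 0 over the text tokens and padding that follow.
    audio_mask = torch.zeros_like(attention_mask)
    for row, length in enumerate(audio_lengths):
        audio_mask[row, :length] = 1
    return audio_mask

attention_mask = torch.ones(2, 8, dtype=torch.long)
print(build_audio_mask(attention_mask, [3, 5]))
# tensor([[1, 1, 1, 0, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 0, 0, 0]])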
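
Note on the projector hunk: the view call implements temporal downsampling by frame stacking, folding every k consecutive encoder frames into one dim * k vector before a two-layer MLP maps it to the LLM embedding size. A minimal runnable sketch follows; the class name ConcatProjector, the divisibility guard, and all concrete dimensions are illustrative assumptions, not taken from the repo.

import torch
import torch.nn as nn

class ConcatProjector(nn.Module):
    # Downsamples encoder output along time by stacking k adjacent frames,
    # then projects the stacked vectors into the LLM embedding space.
    def __init__(self, encoder_dim, llm_dim, k=5, hidden=2048):
        super().__init__()
        self.k = k
        self.linear1 = nn.Linear(encoder_dim * k, hidden)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden, llm_dim)

    def forward(self, x):
        batch_size, seq_len, dim = x.size()
        # Drop trailing frames so seq_len divides evenly by k (the hunk above
        # assumes divisibility; this guard is an addition for the sketch).
        seq_len = seq_len - seq_len % self.k
        x = x[:, :seq_len, :].contiguous()
        # (B, T, D) -> (B, T // k, D * k): each output step is the
        # concatenation of k consecutive encoder frames.
        x = x.view(batch_size, seq_len // self.k, dim * self.k)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

# Whisper's encoder emits 1500 frames for 30 s of audio; with k = 5 the
# LLM receives a 300-step sequence.
proj = ConcatProjector(encoder_dim=1280, llm_dim=4096)
print(proj(torch.randn(2, 1500, 1280)).shape)  # torch.Size([2, 300, 4096])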