diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py
index df559de1a..ec7f9ec28 100644
--- a/swift/llm/utils/template.py
+++ b/swift/llm/utils/template.py
@@ -1008,13 +1008,19 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
         if 'inputs_embeds' in batch[0]:
             inputs_embeds = [b['inputs_embeds'] for b in batch]
             res['inputs_embeds'] = inputs_embeds
-            res['attention_mask'] = [
-                torch.ones((inputs_embeds[i].shape[0]), dtype=torch.int64) for i in range(len(inputs_embeds))
-            ]
+            if 'attention_mask' in batch[0]:
+                res['attention_mask'] = batch[0]['attention_mask']
+            else:
+                res['attention_mask'] = [
+                    torch.ones((inputs_embeds[i].shape[0]), dtype=torch.int64) for i in range(len(inputs_embeds))
+                ]
         elif 'input_ids' in batch[0]:
             input_ids = [torch.tensor(b['input_ids']) for b in batch]
             res['input_ids'] = input_ids
-            res['attention_mask'] = [torch.ones(len(input_ids[i]), dtype=torch.int64) for i in range(len(input_ids))]
+            if 'attention_mask' in batch[0]:
+                res['attention_mask'] = batch[0]['attention_mask']
+            else:
+                res['attention_mask'] = [torch.ones(len(input_ids[i]), dtype=torch.int64) for i in range(len(input_ids))]
         for key in ['labels', 'loss_scale', 'position_ids']:
             if key in batch[0]:
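
The net effect of this hunk: when a batch item already carries an 'attention_mask', the collator now passes it through instead of unconditionally rebuilding an all-ones mask from the sequence lengths. Below is a minimal standalone sketch of that behavior for illustration only; collate_attention_mask is a hypothetical name, not part of swift, and it reduces data_collator to just the attention_mask branches shown in the diff.

# Hedged sketch of the patched attention_mask logic, not the swift API.
from typing import Any, Dict, List

import torch


def collate_attention_mask(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    res: Dict[str, Any] = {}
    if 'inputs_embeds' in batch[0]:
        inputs_embeds = [b['inputs_embeds'] for b in batch]
        res['inputs_embeds'] = inputs_embeds
        if 'attention_mask' in batch[0]:
            # New behavior: respect the caller-supplied mask.
            res['attention_mask'] = batch[0]['attention_mask']
        else:
            # Fallback (old behavior): an all-ones mask per embedding sequence.
            res['attention_mask'] = [
                torch.ones(e.shape[0], dtype=torch.int64) for e in inputs_embeds
            ]
    elif 'input_ids' in batch[0]:
        input_ids = [torch.tensor(b['input_ids']) for b in batch]
        res['input_ids'] = input_ids
        if 'attention_mask' in batch[0]:
            res['attention_mask'] = batch[0]['attention_mask']
        else:
            res['attention_mask'] = [
                torch.ones(len(ids), dtype=torch.int64) for ids in input_ids
            ]
    return res


# Usage: a precomputed mask (here with padding zeros) is no longer overwritten.
batch = [{
    'inputs_embeds': torch.randn(4, 8),
    'attention_mask': torch.tensor([1, 1, 1, 0]),
}]
out = collate_attention_mask(batch)
assert torch.equal(out['attention_mask'], batch[0]['attention_mask'])

Note that, as written in the diff, the pass-through reads only batch[0]['attention_mask']; this matches the patch as shown and presumably targets callers that supply a single pre-collated mask for the batch.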