From 3c930ac3bc68134e29dc0958af60fcf19cc05630 Mon Sep 17 00:00:00 2001
From: Yerong Li
Date: Sun, 13 Oct 2024 04:29:43 -0500
Subject: [PATCH] Allow users to pass attention_mask in data_collator

Give users the flexibility to supply their own attention_mask through the
data_collator. If the samples in a batch already carry an attention_mask,
collate the provided masks into the result (one mask per sample); otherwise
fall back to synthesizing all-ones masks, as before.
---
 swift/llm/utils/template.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py
index df559de1a..ec7f9ec28 100644
--- a/swift/llm/utils/template.py
+++ b/swift/llm/utils/template.py
@@ -1008,13 +1008,19 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
         if 'inputs_embeds' in batch[0]:
             inputs_embeds = [b['inputs_embeds'] for b in batch]
             res['inputs_embeds'] = inputs_embeds
-            res['attention_mask'] = [
-                torch.ones((inputs_embeds[i].shape[0]), dtype=torch.int64) for i in range(len(inputs_embeds))
-            ]
+            if 'attention_mask' in batch[0]:
+                res['attention_mask'] = [b['attention_mask'] for b in batch]
+            else:
+                res['attention_mask'] = [
+                    torch.ones((inputs_embeds[i].shape[0]), dtype=torch.int64) for i in range(len(inputs_embeds))
+                ]
         elif 'input_ids' in batch[0]:
             input_ids = [torch.tensor(b['input_ids']) for b in batch]
             res['input_ids'] = input_ids
-            res['attention_mask'] = [torch.ones(len(input_ids[i]), dtype=torch.int64) for i in range(len(input_ids))]
+            if 'attention_mask' in batch[0]:
+                res['attention_mask'] = [torch.tensor(b['attention_mask']) for b in batch]
+            else:
+                res['attention_mask'] = [torch.ones(len(input_ids[i]), dtype=torch.int64) for i in range(len(input_ids))]
 
         for key in ['labels', 'loss_scale', 'position_ids']:
             if key in batch[0]:
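
Reviewer note: below is a minimal, standalone sketch of the patched
input_ids branch, for illustration only. The collate_attention_mask helper
and the sample batch are hypothetical and invented for this example; they
are not part of swift/llm/utils/template.py.

    from typing import Any, Dict, List

    import torch

    def collate_attention_mask(batch: List[Dict[str, Any]]) -> List[torch.Tensor]:
        """Standalone illustration of the patched branch: prefer
        user-supplied masks over synthesized all-ones masks."""
        input_ids = [torch.tensor(b['input_ids']) for b in batch]
        if 'attention_mask' in batch[0]:
            # User provided masks: collate one tensor per sample.
            return [torch.tensor(b['attention_mask']) for b in batch]
        # Fallback: all-ones masks matching each sample's length, as before.
        return [torch.ones(len(ids), dtype=torch.int64) for ids in input_ids]

    batch = [
        {'input_ids': [101, 2023, 2003, 102], 'attention_mask': [1, 1, 1, 0]},
        {'input_ids': [101, 2307, 102], 'attention_mask': [1, 1, 1]},
    ]
    print(collate_attention_mask(batch))  # the supplied masks, as int64 tensors

Collating one mask per sample (rather than assigning batch[0]'s mask to the
whole batch) keeps each mask aligned with its own sample's sequence length,
which matters as soon as the batch holds more than one sample.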