# data.py
from dataclasses import dataclass
from typing import Dict, Sequence

import torch
import transformers

# Label value ignored by PyTorch's cross-entropy loss (its default ignore_index).
IGNORE_INDEX = -100


@dataclass
class DataCollatorForLLoCOSFTDataset(object):
    """Collate LLoCO SFT instances into a padded batch with context embeddings."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Every instance must carry precomputed context embeddings.
        assert "inputs_embeds" in instances[0]
        input_ids, labels = tuple(
            [instance[key] for instance in instances]
            for key in ("input_ids", "labels")
        )
        # Stack the per-instance context embeddings into one batch tensor.
        context_embeddings = torch.stack(
            [instance["inputs_embeds"] for instance in instances]
        )
        # Right-pad token ids with the tokenizer's pad token; pad labels with
        # IGNORE_INDEX so padding positions do not contribute to the loss.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        return dict(
            input_ids=input_ids,
            labels=labels,
            inputs_embeds=context_embeddings,
            # Large sentinel value; effectively treats the context as one segment.
            segment_lengths=9999,
            output_hidden_states=True,
        )
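
# A minimal usage sketch for the collator (hypothetical tensors and shapes,
# not part of the original file; the hidden size 4096 is an assumption):
#
#     collator = DataCollatorForLLoCOSFTDataset(tokenizer=tokenizer)
#     batch = collator([
#         {
#             "input_ids": torch.tensor([1, 2, 3]),
#             "labels": torch.tensor([IGNORE_INDEX, 2, 3]),
#             "inputs_embeds": torch.zeros(16, 4096),  # compressed context
#         },
#     ])
#     # batch["input_ids"]:     (batch, max_len), right-padded with pad_token_id
#     # batch["labels"]:        (batch, max_len), right-padded with IGNORE_INDEX
#     # batch["inputs_embeds"]: (batch, 16, 4096), stacked context embeddings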

def make_lloco_data_module(model, tokenizer, dataset_cls, data_args, **kwargs) -> Dict:
    """Make the dataset and collator for supervised fine-tuning.

    Note: `model` is accepted for interface compatibility but is unused here.
    """
    if not data_args.lazy_preprocess:
        raise NotImplementedError("Only lazy preprocessing is supported.")
    train_dataset = dataset_cls(
        tokenizer=tokenizer,
        embedding_path=data_args.embedding_path,
        split="train",
        mode=data_args.eval_mode,
        **kwargs,
    )
    print("Dataset size:", len(train_dataset))
    data_collator = DataCollatorForLLoCOSFTDataset(tokenizer=tokenizer)
    return dict(
        train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
    )