-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_loader.py
79 lines (67 loc) · 2.91 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import pandas as pd
from torch.utils.data import Dataset
import utils
class QuoraDataset(Dataset):
    """Torch Dataset pairing question strings (p, q) with binary labels.

    Items come back as raw strings; all tokenization and index-tensor
    construction happens in ``collate_batch``, which is meant to be passed
    as ``collate_fn`` to a DataLoader.
    """

    def __init__(self, p, q, label, words, chars, seq_len=50, word_len=20, cuda=False):
        """
        Initializes a QuoraDataset object, subclass of torch Dataset

        :param p: list-like of passage strings
        :param q: list-like of passage strings
        :param label: list-like of binary labels [0,1]
        :param words: word vocabulary, forwarded to utils.sentence_to_padded_index_sequence
        :param chars: character vocabulary, forwarded to utils.sentence_to_padded_index_sequence
        :param seq_len: hard cap on per-batch sequence padding length
        :param word_len: per-word character padding length
        :param cuda: forwarded to utils for tensor placement
        """
        self.p = p
        self.q = q
        self.label = label
        self.words = words
        self.chars = chars
        self.seq_len = seq_len
        self.word_len = word_len
        self.cuda = cuda

    def __len__(self):
        """Number of (p, q, label) examples."""
        return len(self.label)

    def __getitem__(self, key):
        """Return ((p_string, q_string), label) for the given key."""
        return (self.p[key], self.q[key]), self.label[key]

    def collate_batch(self, batch):
        """Collate ((p, q), label) examples into padded index tensors.

        Returns ((p_words, p_chars, q_words, q_chars), labels) where the
        four tensors are stacked per-sentence outputs of
        utils.sentence_to_padded_index_sequence and labels is a LongTensor.
        """
        labels = [int(example[1]) for example in batch]
        left_tokens = [example[0][0].split() for example in batch]
        right_tokens = [example[0][1].split() for example in batch]
        # Pad to the longest sentence in the batch, but never past self.seq_len.
        longest = max(len(tokens) for tokens in left_tokens + right_tokens)
        padded_len = min(longest, self.seq_len)

        def encode(sentences):
            # Map each token list to (word_tensor, char_tensor), then stack
            # each component across the batch.
            pairs = [
                utils.sentence_to_padded_index_sequence(
                    tokens, self.words, self.chars,
                    seq_len=padded_len, word_len=self.word_len, cuda=self.cuda)
                for tokens in sentences
            ]
            words = torch.stack([w for w, _ in pairs])
            chars = torch.stack([c for _, c in pairs])
            return words, chars

        p_words, p_chars = encode(left_tokens)
        q_words, q_chars = encode(right_tokens)
        return (p_words, p_chars, q_words, q_chars), torch.LongTensor(labels)
def make_dataloader(df, words, chars, seq_len=50, word_len=20, batch_size=128, shuffle=True, cuda=False):
    """
    Returns a pytorch DataLoader of the Quora dataset

    :param df: a pandas DataFrame-like with columns ['p', 'q', 'label']
    :param words: dictionary of vocabs
    :param chars: dictionary of character vocabs
    :param seq_len: sequence length to pad to
    :param word_len: word length to pad to
    :param batch_size: examples per batch
    :param shuffle: whether the DataLoader reshuffles every epoch
    :param cuda: forwarded to the dataset for tensor placement
    :return: a pytorch DataLoader object
    """
    # Materialize columns as plain lists so __getitem__ indexing is positional.
    # The DataLoader sampler yields indices from range(len(dataset)); a pandas
    # Series with a non-contiguous index (e.g. a train/test split that was not
    # reset) would otherwise be indexed by *label*, raising KeyError or
    # returning the wrong rows. For a default RangeIndex this is a no-op.
    dataset = QuoraDataset(list(df['p']), list(df['q']), list(df['label']),
                           words, chars, seq_len=seq_len, word_len=word_len, cuda=cuda)
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=batch_size,
                                               collate_fn=dataset.collate_batch,
                                               shuffle=shuffle)
    return train_loader