-
Notifications
You must be signed in to change notification settings - Fork 1
/
msrvtt_dataset.py
135 lines (102 loc) · 4.27 KB
/
msrvtt_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import pickle as pkl
from torch.utils.data import Dataset
import utils.sys_utils as utils
import numpy as np
class MSRVTTDataset(Dataset):
    '''
    Loads paired video and sentence features for every sentence/video-segment
    pair in the MSRVTT video-caption dataset.

    Example Usage:
        import torch
        from msrvtt_dataset import MSRVTTDataset as MSRVTT
        repo_dir = '/usr/local/extstore01/zahra/datasets/MSRVTT/'
        video_feats_dir = f'{repo_dir}/feats/video/r2plus1d_TrainVal'
        text_feats_path = f'{repo_dir}/feats/text/msrvtt_captions_np.pkl'
        ids_path = f'{repo_dir}/TrainVal_videoid_sentid.txt'
        dl_params = {'batch_size': 64, 'shuffle': True, 'num_workers': 1}
        dataset = MSRVTT(vid_feats_dir=video_feats_dir,
                         txt_feats_path=text_feats_path,
                         ids_path=ids_path, transform=None)
        dataloader = torch.utils.data.DataLoader(dataset, **dl_params)
    '''

    def __init__(self, vid_feats_dir, txt_feats_path, ids_path, transform=None):
        '''
        Args:
            vid_feats_dir (str): path to the directory of pre-computed video features
            txt_feats_path (str): path to the sentence-features pickle file
            ids_path (str): path to a text file with one '<videoid>_<sentid>' per line
            transform (callable, optional): optional transform applied to each
                (video, sentence) tensor pair in __getitem__
        '''
        # Load pre-computed video features, sentence features, and pair ids.
        vids = utils.load_video_feats(vid_feats_dir)
        sens = utils.load_picklefile(txt_feats_path)
        vidid_sentid = utils.load_textfile(ids_path)
        # Map each sentence id to its video id, preserving file order in sen_id.
        self.s2v_id = {}
        self.sen_id = []
        for item in vidid_sentid:
            # maxsplit=1 tolerates additional underscores inside the sentence id
            # (plain split('_') would raise ValueError on such lines).
            vidid, senid = item.split('_', 1)
            self.s2v_id[senid] = vidid
            self.sen_id.append(senid)
        self.videos = vids
        self.transform = transform
        # Sentence features are assumed to be aligned with the id-file order
        # — TODO confirm against the pickle's producer.
        self.sents = dict(zip(self.sen_id, sens))

    def __len__(self):
        '''Returns the number of sentence/video pairs in the dataset.'''
        return len(self.sen_id)

    def __getitem__(self, idx):
        '''
        Args:
            idx (int): sample index
        Returns:
            tuple: (video, sentence) float tensors for the idx-th pair
        '''
        sen_id = self.sen_id[idx]
        vid_id = self.s2v_id[sen_id]
        v = torch.tensor(self.videos[vid_id].squeeze()).float()
        t = torch.tensor(self.sents[sen_id].squeeze()).float()
        if self.transform is not None:
            v, t = self.transform((v, t))
            # Transforms (e.g. standardization with broadcast stats) may add
            # singleton dims; squeeze them back out.
            v = v.squeeze()
            t = t.squeeze()
        return (v, t)

    def get_dataset_mean_std(self):
        '''
        Computes mean and standard deviation over all video features and over
        all sentence features, separately (per-dimension, across samples).

        Returns:
            dict: {'videos': (mean, std), 'sents': (mean, std)} as numpy arrays
        '''
        # Videos: stack all feature arrays and reduce over the sample axis.
        feats = np.array(list(self.videos.values()))
        v_mean = feats.mean(axis=0)
        v_std = feats.std(axis=0)
        # Sentences: same reduction over all sentence features.
        feats = np.array(list(self.sents.values()))
        t_mean = feats.mean(axis=0)
        t_std = feats.std(axis=0)
        return {'videos': (v_mean, v_std), 'sents': (t_mean, t_std)}
class Standardize_VideoSentencePair(object):
    '''
    Transform that standardizes a (video, sentence) feature pair to zero mean
    and unit variance using pre-computed dataset statistics.
    '''

    def __init__(self, dataset_stats):
        # dataset_stats has the shape produced by MSRVTTDataset.get_dataset_mean_std:
        # {'videos': (mean, std), 'sents': (mean, std)}
        self.v_mean, self.v_std = dataset_stats['videos']
        self.t_mean, self.t_std = dataset_stats['sents']

    def __call__(self, sample):
        video, sent = sample
        video = (video - self.v_mean) / self.v_std
        sent = (sent - self.t_mean) / self.t_std
        return (video, sent)
class ToTensor_VideoSentencePair(object):
    '''
    Transform that converts a (video, sentence) feature pair into torch tensors.
    '''

    def __call__(self, sample):
        video, sent = sample
        return (torch.tensor(video), torch.tensor(sent))