generated from victoresque/pytorch-template
-
Notifications
You must be signed in to change notification settings - Fork 1
/
data_loader.py
76 lines (62 loc) · 3.67 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from torch.utils.data import DataLoader
import dataset
import torch
import pickle
import pandas as pd
import os
import numpy as np
import random
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    A batch such as ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``.
    An empty batch yields ``()``. Because ``zip`` is used, fields beyond the
    length of the shortest sample are dropped.
    """
    columns = zip(*batch)
    return tuple(columns)
def safe_collate(batch):
    """Collate ``batch`` after discarding every sample that is ``None``.

    The surviving samples are handed to ``collate_fn`` in their original
    order; if nothing survives, ``collate_fn`` receives an empty list.
    """
    kept_samples = [sample for sample in batch if sample is not None]
    return collate_fn(kept_samples)
# Added by nir on 22.6.21: worker seeding taken from https://pytorch.org/docs/stable/notes/randomness.html
def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from torch's initial seed.

    Args:
        worker_id: Accepted for the ``worker_init_fn`` call signature; it is
            not used by this function.
    """
    # Reduce the torch seed modulo 2**32 so it fits the range NumPy accepts.
    seed_value = torch.initial_seed() % (1 << 32)
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(seed_value)
# Fixed seed for the module-level torch generator below.
SEED = 0
# Generator object seeded once at import time with SEED.
g = torch.Generator()
g.manual_seed(SEED)
def get_data_loaders(shuffled_epochs, train_procedure, dataset_path, samples_count, batch_size, max_sequence_length, time_weight):
    """Build the train/validation/test/future DataLoaders from ``dataset_path``.

    Reads ``syscalls_tensors.pickle`` and one ``<split>_balanced_order.csv``
    per split from ``dataset_path``, regroups the splits according to
    ``train_procedure`` and wraps every remaining split in a ``DataLoader``.

    Args:
        shuffled_epochs: Used as the ``shuffle`` flag of the train DataLoader
            only; every other split is loaded with ``shuffle=False``.
        train_procedure: How the splits are regrouped before loading:
            ``'train'`` keeps them as read from disk;
            ``'retrain'`` trains on train + validation, validates on the test
            split and produces no test loader;
            ``'double_retrain'`` trains on train + validation + test and
            produces neither a validation nor a test loader.
            Any other value leaves the splits untouched (same as ``'train'``).
        dataset_path: Directory containing the pickle and the CSV files.
        samples_count: Forwarded to ``dataset.TensorPickleSCNNCSVDataset``.
        batch_size: Forwarded to the dataset and used as the DataLoader
            batch size.
        max_sequence_length: Forwarded to ``dataset.TensorPickleSCNNCSVDataset``.
        time_weight: Forwarded to ``dataset.TensorPickleSCNNCSVDataset``.

    Returns:
        A 4-tuple ``(train, validation, test, future)``; each entry is a
        ``DataLoader`` or ``None`` when ``train_procedure`` removed that split.

    Raises:
        FileNotFoundError: If the pickle or one of the CSV files is missing.
    """
    # TODO: set num_workers for pytorch
    num_workers = 4  # TODO: both 1 and 4 had bug in pytorch19cuda11 - both pip-python3.9 and conda-python3.8
    # num_workers = 4 * torch.cuda.device_count()  # TODO: BUG: more than >1 is not working for multi GPU !!!

    split_names = ('train', 'validation', 'test', 'future')

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file, so dataset_path must only ever point at trusted data.
    # The context manager closes the file handle once loading is done.
    with open(os.path.join(dataset_path, "syscalls_tensors.pickle"), "rb") as tensors_file:
        syscalls_tensors = pickle.load(tensors_file)
    syscalls_tensors_df = pd.DataFrame.from_dict(syscalls_tensors, orient='index')

    # One "balanced order" CSV per split, keyed by split name.
    sets_dict_df_balanced_order_csv = {}
    for split_name in split_names:
        split_csv_path = os.path.join(dataset_path, split_name + "_balanced_order.csv")
        sets_dict_df_balanced_order_csv[split_name] = pd.read_csv(split_csv_path)

    # Regroup the splits; a split set to None gets no DataLoader below.
    if train_procedure == 'retrain':
        sets_dict_df_balanced_order_csv['train'] = pd.concat([
            sets_dict_df_balanced_order_csv['train'],
            sets_dict_df_balanced_order_csv['validation'],
        ])
        sets_dict_df_balanced_order_csv['validation'] = sets_dict_df_balanced_order_csv['test']
        sets_dict_df_balanced_order_csv['test'] = None
    elif train_procedure == 'double_retrain':
        sets_dict_df_balanced_order_csv['train'] = pd.concat([
            sets_dict_df_balanced_order_csv['train'],
            sets_dict_df_balanced_order_csv['validation'],
            sets_dict_df_balanced_order_csv['test'],
        ])
        sets_dict_df_balanced_order_csv['validation'] = None
        sets_dict_df_balanced_order_csv['test'] = None
    # 'train' (and any unrecognised value) keeps the splits exactly as read.

    sets_dict_dataloader = dict.fromkeys(split_names)
    for split_name in split_names:
        split_order_df = sets_dict_df_balanced_order_csv[split_name]
        if split_order_df is None:
            continue

        # First, create the PyTorch dataset object for this split.
        split_dataset = dataset.TensorPickleSCNNCSVDataset(
            syscalls_tensors_df,
            split_order_df,
            samples_count,
            batch_size,
            max_sequence_length,
            time_weight,
        )

        # Only the train split may be shuffled (controlled by shuffled_epochs);
        # the evaluation splits are always loaded in their stored order.
        shuffle = shuffled_epochs if split_name == 'train' else False

        # Then build the PyTorch DataLoader for this split.
        sets_dict_dataloader[split_name] = DataLoader(
            dataset=split_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=safe_collate,
            pin_memory=False,  # TODO: default is pin_memory=True
            worker_init_fn=seed_worker,
            generator=g,
        )

    return (
        sets_dict_dataloader['train'],
        sets_dict_dataloader['validation'],
        sets_dict_dataloader['test'],
        sets_dict_dataloader['future'],
    )