Commit

fix(): clean main.py
pedramabdzadeh committed Oct 22, 2021
1 parent 4044d3a commit 42f197b
Showing 1 changed file with 6 additions and 91 deletions.
main.py: 97 changes (6 additions & 91 deletions)
@@ -33,25 +33,16 @@ class Color:


def add_parser(parser):
# parser = argparse.ArgumentParser(description=__doc__)

# Data folder prepare
parser.add_argument("-a", "--access_type", type=str, help="LA or PA", default='LA')

# Dataset prepare
parser.add_argument("--feat_len", type=int, help="features length", default=750)
parser.add_argument('--padding', type=str, default='repeat', choices=['zero', 'repeat'],
help="how to pad short utterance")
parser.add_argument("--enc_dim", type=int, help="encoding dimension", default=256)

# Training hyperparameters

parser.add_argument('--num-epochs', type=int, default=100, help="Number of epochs for training")
parser.add_argument('--num-folds', type=int, default=5, help="Number of folds for training")
parser.add_argument('--batch-size', type=int, default=4, help="Mini batch size for training")
parser.add_argument('--epoch', type=int, default=0, help="current epoch number")
parser.add_argument('--lr', type=float, default=0.0003, help="learning rate")
parser.add_argument('--lr-decay', type=float, default=0.5, help="decay learning rate")
parser.add_argument('--interval', type=int, default=10, help="interval to decay lr")
parser.add_argument('--epoch', type=int, default=0, help="current epoch number")

parser.add_argument('--beta-1', type=float, default=0.9, help="beta_1 for Adam")
parser.add_argument('--beta-2', type=float, default=0.999, help="beta_2 for Adam")
@@ -60,19 +51,11 @@ def add_parser(parser):
parser.add_argument('--num-workers', type=int, default=0, help="number of workers")
parser.add_argument('--seed', type=int, help="random number seed", default=598)

parser.add_argument('--add-loss', type=str, default="ocsoftmax",
choices=["softmax", 'amsoftmax', 'ocsoftmax'], help="loss for one-class training")
parser.add_argument('--weight-loss', type=float, default=1, help="weight for other loss")
parser.add_argument('--r-real', type=float, default=0.9, help="r_real for ocsoftmax")
parser.add_argument('--r-fake', type=float, default=0.2, help="r_fake for ocsoftmax")
parser.add_argument('--alpha', type=float, default=20, help="scale factor for ocsoftmax")

parser.add_argument('--model-path', type=str, help="saved model path")
# parser.add_argument('--loss_model-path', type=str, help="saved loss model path")

parser.add_argument('--continue_training', action='store_true',
help="continue training with previously trained model")

args = parser.parse_args()

# Change this to specify GPU
@@ -81,22 +64,6 @@ def add_parser(parser):
# Set seeds
setup_seed(args.seed)

if args.continue_training:
assert os.path.exists(args.out_fold)
else:
# Path for output data
if not os.path.exists('./log/'):
os.makedirs('./log/')

# Save training arguments
with open(os.path.join('./log/', 'args.json'), 'w') as file:
file.write(json.dumps(vars(args), sort_keys=True, separators=('\n', ':')))

with open(os.path.join('./log/', 'train_loss.log'), 'w') as file:
file.write("Start recording training loss ...\n")
with open(os.path.join('./log/', 'dev_loss.log'), 'w') as file:
file.write("Start recording validation loss ...\n")

# assign device
args.cuda = torch.cuda.is_available()
print('Cuda device available: ', args.cuda)
@@ -122,23 +89,8 @@ def pad(x, max_len=64000):
return padded_x


def prepare_weights_to_fix_imbalance(dataset, train_ids):
class_sample_count = np.array([1, 9])
weight = 1. / class_sample_count
samples_weight = []

for t in train_ids:
_, key, _ = dataset.__getitem__(t)
samples_weight.append(weight[int(key)])

samples_weight = np.array(samples_weight)

return samples_weight


def split_dataset_to_train_and_val(k_fold, train_set, batch_size):
for fold, (train_ids, test_ids) in enumerate(k_fold.split(train_set)):
# Print
# Sample elements randomly from a given list of ids, no replacement.
train_sub_sampler = torch.utils.data.SubsetRandomSampler(train_ids)
test_sub_sampler = torch.utils.data.SubsetRandomSampler(test_ids)
@@ -165,7 +117,7 @@ def train(parser, device):
lambda x: Tensor(x),
])

k_fold = KFold(n_splits=args.num_folds, shuffle=True)
k_fold = KFold(n_splits=5, shuffle=True)

if args.model_path:
model.load_state_dict(torch.load(args.model_path))
@@ -178,58 +130,22 @@ def train(parser, device):
train_set = dataset_loader.ASVDataset(is_train=True, transform=transforms)
# dev_set = dataset_loader.ASVDataset(is_train=False, transform=transforms)

# number_of_epochs = int(args.num_epochs / args.num_folds)
# checkpoint_epoch = args.epoch % number_of_epochs
# checkpoint_fold = args.epoch // number_of_epochs

monitor_loss = args.add_loss
monitor_loss = 'loss'

print(f'{Color.ENDC}Train Start...')

# for fold, (train_ids, test_ids) in enumerate(k_fold.split(train_set)):
# if checkpoint_fold > fold:
# continue

# print(f'{Color.UNDERLINE}{Color.WARNING}Fold {fold}{Color.ENDC}')

# weights = prepare_weights_to_fix_imbalance(train_set, train_ids)
# weighted_sampler = torch.utils.data.WeightedRandomSampler(weights=weights, num_samples=len(train_ids),
# replacement=True)
train_loader, validation_loader = split_dataset_to_train_and_val(k_fold, train_set, batch_size=args.batch_size)

# train_loader_part = torch.utils.data.DataLoader(
# train_set,
# batch_size=args.batch_size,
# shuffle=True
# )
# validation_loader_part = torch.utils.data.DataLoader(
# train_set,
# batch_size=args.batch_size, shuffle=True
# )

model.train()

# train_sub_sampler = torch.utils.data.Subset(train_set, train_ids)
# test_sub_sampler = torch.utils.data.Subset(train_set, test_ids)

# train_loader = torch.utils.data.DataLoader(
# train_sub_sampler, batch_size=args.batch_size, shuffle=True)
#
# validation_loader = torch.utils.data.DataLoader(
# validation_sub_sampler,
# batch_size=args.batch_size)

# for epoch in range(checkpoint_epoch, number_of_epochs):
for epoch in range(args.epoch, args.num_epochs):
start = time.time()

print(f'{Color.OKBLUE}Epoch:{epoch}{Color.ENDC}')
# train_loader, validation_loader = k_fold_cross_validation(k_fold, train_set, batch_size=args.batch_size)
model.train()
train_loss_dict = defaultdict(list)
dev_loss_dict = defaultdict(list)

# adjust_learning_rate(args, optimizer, fold * number_of_epochs + epoch)
adjust_learning_rate(args, optimizer, epoch)

for batch_x, batch_y, batch_meta in train_loader:
@@ -238,7 +154,7 @@ def train(parser, device):

labels = batch_y.to(device)
loss, score = model(batch_x, labels)
train_loss_dict[args.add_loss].append(loss.item())
train_loss_dict[monitor_loss].append(loss.item())

optimizer.zero_grad()
loss.backward()
@@ -263,7 +179,7 @@ def train(parser, device):
labels = batch_y.to(device)
loss, score = model(batch_x, labels, False)

dev_loss_dict[args.add_loss].append(loss.item())
dev_loss_dict['loss'].append(loss.item())
idx_loader.append(labels)
score_loader.append(score)

@@ -274,7 +190,6 @@
val_eer = min(val_eer, other_val_eer)

with open(os.path.join('./log/', "dev_loss.log"), "a") as log:
# log.write(str(fold) + "\t" + str(epoch) + "\t" + str(
log.write(str(epoch) + "\t" + str(
np.nanmean(dev_loss_dict[monitor_loss])) + "\t" + str(
val_eer) + "\n")
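The pad() hunk above is only partially expanded, so the function body is not visible on this page. As a rough sketch of the zero/repeat padding described by the --padding option ("how to pad short utterance"), a typical implementation might look like the following; the mode parameter and the exact truncation behaviour are assumptions, not code taken from this commit.

import numpy as np

def pad(x, max_len=64000, mode='repeat'):
    # Sketch only: pad or truncate a 1-D waveform to max_len samples.
    x_len = x.shape[0]
    if x_len >= max_len:
        return x[:max_len]
    if mode == 'zero':
        # Append zeros up to max_len.
        return np.concatenate([x, np.zeros(max_len - x_len, dtype=x.dtype)])
    # 'repeat': tile the utterance until it covers max_len, then truncate.
    num_repeats = int(np.ceil(max_len / x_len))
    padded_x = np.tile(x, num_repeats)[:max_len]
    return padded_x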

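Similarly, the new k_fold = KFold(n_splits=5, shuffle=True) line feeds split_dataset_to_train_and_val(), whose return statement falls outside the hunks shown above. Below is a minimal, self-contained sketch of that kind of split into train and validation DataLoaders; the assumption that the helper returns loaders for the first fold only is ours, not something visible in the diff.

import torch
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

def split_dataset_to_train_and_val(k_fold, train_set, batch_size):
    # Take the first train/validation split produced by KFold.
    for fold, (train_ids, val_ids) in enumerate(k_fold.split(train_set)):
        # Sample elements randomly from the given index lists, without replacement.
        train_sampler = SubsetRandomSampler(train_ids)
        val_sampler = SubsetRandomSampler(val_ids)
        train_loader = DataLoader(train_set, batch_size=batch_size, sampler=train_sampler)
        val_loader = DataLoader(train_set, batch_size=batch_size, sampler=val_sampler)
        return train_loader, val_loader

# Toy usage: 40 random samples, 5 folds, mini-batches of 4.
if __name__ == "__main__":
    data = TensorDataset(torch.randn(40, 8), torch.randint(0, 2, (40,)))
    k_fold = KFold(n_splits=5, shuffle=True)
    train_loader, val_loader = split_dataset_to_train_and_val(k_fold, data, batch_size=4)
    print(len(train_loader), len(val_loader))  # e.g. 8 train batches, 2 validation batches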