fix(): fix kfold val implementation, some test problems, and batch imbalance
pedramabdzadeh committed Oct 9, 2021
1 parent e2e4ce0 commit f2c0421
Showing 4 changed files with 82 additions and 102 deletions.
172 changes: 76 additions & 96 deletions main.py
@@ -45,11 +45,11 @@ def add_parser(parser):
     parser.add_argument("--enc_dim", type=int, help="encoding dimension", default=256)

     # Training hyperparameters
-    parser.add_argument('--num-epochs', type=int, default=1, help="Number of epochs for training")
+    parser.add_argument('--num-epochs', type=int, default=100, help="Number of epochs for training")
     parser.add_argument('--batch-size', type=int, default=4, help="Mini batch size for training")
-    parser.add_argument('--lr', type=float, default=0.0003, help="learning rate")
+    parser.add_argument('--lr', type=float, default=0.001, help="learning rate")
     parser.add_argument('--lr-decay', type=float, default=0.5, help="decay learning rate")
-    parser.add_argument('--interval', type=int, default=10, help="interval to decay lr")
+    parser.add_argument('--interval', type=int, default=20, help="interval to decay lr")
     parser.add_argument('--epoch', type=int, default=0, help="starting epoch (used when resuming training)")

     parser.add_argument('--beta-1', type=float, default=0.9, help="beta_1 for Adam")
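
Note: --lr-decay and --interval feed adjust_learning_rate(), which train() calls every epoch but which sits outside this diff. A minimal step-decay sketch consistent with those flags (an assumption, not the repository's actual code):

    def adjust_learning_rate(args, optimizer, epoch):
        # Hypothetical sketch: scale the LR by lr-decay (0.5) every `interval` epochs.
        lr = args.lr * (args.lr_decay ** (epoch // args.interval))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr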
@@ -121,39 +121,6 @@ def pad(x, max_len=64000):
     return padded_x


-def evaluate_accuracy(data_loader, model, device):
-    num_correct = 0.0
-    num_total = 0.0
-    model.eval()
-    for batch_x, batch_y, batch_meta in data_loader:
-        batch_size = batch_x.size(0)
-        num_total += batch_size
-        batch_x = batch_x.to(device)
-        batch_y = batch_y.view(-1).type(torch.int64).to(device)
-        batch_out = model(batch_x)
-        _, batch_pred = batch_out.max(dim=1)
-        num_correct += (batch_pred == batch_y).sum(dim=0).item()
-    return 100 * (num_correct / num_total)
-
-
-def k_fold_cross_validation(k_fold, train_set, batch_size):
-    for fold, (train_ids, test_ids) in enumerate(k_fold.split(train_set)):
-        # Sample elements randomly from a given list of ids, no replacement.
-        train_sub_sampler = torch.utils.data.SubsetRandomSampler(train_ids)
-        test_sub_sampler = torch.utils.data.SubsetRandomSampler(test_ids)
-
-        # Define data loaders for training and testing data in this fold
-        train_loader_part = torch.utils.data.DataLoader(
-            train_set,
-            batch_size=batch_size, sampler=train_sub_sampler)
-        validation_loader_part = torch.utils.data.DataLoader(
-            train_set,
-            batch_size=batch_size, sampler=test_sub_sampler)
-
-        return train_loader_part, validation_loader_part
-
-
 def train(parser, device):
     print(f'{Color.OKGREEN}Loading train dataset...{Color.ENDC}')
     args = parser.parse_args()
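
Why these helpers went away: k_fold_cross_validation returned from inside its fold loop, so callers only ever got fold 0's loaders and the remaining folds were never visited; evaluate_accuracy was unused. The commit inlines the fold loop into train() (next hunk) and divides the epoch budget by 5, matching a 5-fold split (k_fold itself is defined outside this diff). A generator would have been another way to fix the early return; a sketch under the same assumptions:

    def k_fold_loaders(k_fold, train_set, batch_size):
        # Yield a (train, validation) loader pair per fold rather than
        # returning inside the loop, which stopped after the first fold.
        for train_ids, test_ids in k_fold.split(train_set):
            train_subset = torch.utils.data.Subset(train_set, train_ids)
            val_subset = torch.utils.data.Subset(train_set, test_ids)
            yield (torch.utils.data.DataLoader(train_subset, batch_size=batch_size),
                   torch.utils.data.DataLoader(val_subset, batch_size=batch_size))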
@@ -171,80 +138,97 @@ def train(parser, device):
         model.load_state_dict(torch.load(args.model_path))
         print('Model loaded : {}'.format(args.model_path))

-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
-                                 betas=(args.beta_1, args.beta_2), eps=args.eps, weight_decay=0.0005)
+    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
+    #                              betas=(args.beta_1, args.beta_2), eps=args.eps, weight_decay=0.0005)
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

     train_set = dataset_loader.ASVDataset(is_train=True, transform=transforms)

-    number_of_epochs = args.num_epochs
+    number_of_epochs = int(args.num_epochs / 5)
     monitor_loss = args.add_loss

     model.train()

     print(f'{Color.ENDC}Train Start...')

-    for epoch in range(args.epoch, number_of_epochs):
-        start = time.time()
-
-        print(f'{Color.OKBLUE}Epoch:{epoch}{Color.ENDC}')
-        train_loader, validation_loader = k_fold_cross_validation(k_fold, train_set, batch_size=args.batch_size)
-        model.train()
-        train_loss_dict = defaultdict(list)
-        dev_loss_dict = defaultdict(list)
-
-        adjust_learning_rate(args, optimizer, epoch)
-
-        for batch_x, batch_y, batch_meta in train_loader:
-            batch_x = batch_x.to(device)
-            batch_y = batch_y.view(-1).type(torch.int64).to(device)
-
-            labels = batch_y.to(device)
-            loss, score = model(batch_x, labels)
-
-            optimizer.zero_grad()
-
-            train_loss_dict[args.add_loss].append(loss.item())
-            loss.backward()
-            optimizer.step()
-
-        with open(os.path.join('./log/', 'train_loss.log'), 'a') as log:
-            log.write(str(epoch) + "\t" + "\t" +
-                      str(np.nanmean(train_loss_dict[monitor_loss])) + "\n")
-
-        end = time.time()
-        hours, rem = divmod(end - start, 3600)
-        minutes, seconds = divmod(rem, 60)
-        print("{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))
-        print('start validation phase...')
-
-        # Val the model
-        model.eval()
-        with torch.no_grad():
-            idx_loader, score_loader = [], []
-            for i, (batch_x, batch_y, batch_meta) in enumerate(validation_loader):
-                labels = batch_y.to(device)
-                loss, score = model(batch_x, labels)
-
-                dev_loss_dict[args.add_loss].append(loss.item())
-                idx_loader.append(labels)
-                score_loader.append(score)
-
-            scores = torch.cat(score_loader, 0).data.cpu().numpy()
-            labels = torch.cat(idx_loader, 0).data.cpu().numpy()
-            val_eer = em.compute_eer(scores[labels == 0], scores[labels == 1])[0]
-            other_val_eer = em.compute_eer(-scores[labels == 0], -scores[labels == 1])[0]
-            val_eer = min(val_eer, other_val_eer)
-
-        with open(os.path.join('./log/', "dev_loss.log"), "a") as log:
-            log.write(str(epoch) + "\t" + str(np.nanmean(dev_loss_dict[monitor_loss])) + "\t" + str(
-                val_eer) + "\n")
-            print("Val EER: {}".format(val_eer))
-
-        torch.save(model.state_dict(), os.path.join('./models/', 'model_%d.pt' % (epoch + 1)))
-        end = time.time()
-        hours, rem = divmod(end - start, 3600)
-        minutes, seconds = divmod(rem, 60)
-        print("{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))
+    for fold, (train_ids, test_ids) in enumerate(k_fold.split(train_set)):
+        print(f'{Color.UNDERLINE}{Color.WARNING}Fold {fold}{Color.ENDC}')
+
+        weighted_sampler = torch.utils.data.WeightedRandomSampler(weights=[1, 10], num_samples=len(train_ids),
+                                                                  replacement=True)
+
+        train_sub_sampler = torch.utils.data.Subset(train_set, train_ids)
+        test_sub_sampler = torch.utils.data.Subset(train_set, test_ids)
+
+        train_loader = torch.utils.data.DataLoader(
+            train_sub_sampler,
+            batch_size=args.batch_size, sampler=weighted_sampler)
+
+        validation_loader = torch.utils.data.DataLoader(
+            test_sub_sampler,
+            batch_size=args.batch_size)
+
+        for epoch in range(args.epoch, number_of_epochs):
+            start = time.time()
+
+            print(f'{Color.OKBLUE}Epoch:{epoch}{Color.ENDC}')
+            # train_loader, validation_loader = k_fold_cross_validation(k_fold, train_set, batch_size=args.batch_size)
+            model.train()
+            train_loss_dict = defaultdict(list)
+            dev_loss_dict = defaultdict(list)

+            adjust_learning_rate(args, optimizer, epoch)
+            for batch_x, batch_y, batch_meta in train_loader:
+                batch_x = batch_x.to(device)
+                batch_y = batch_y.view(-1).type(torch.int64).to(device)
+
+                labels = batch_y.to(device)
+                loss, score = model(batch_x, labels)
+
+                train_loss_dict[args.add_loss].append(loss.item())
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+            with open(os.path.join('./log/', 'train_loss.log'), 'a') as log:
+                log.write(str(fold) + "\t" + str(epoch) + "\t" +
+                          str(np.nanmean(train_loss_dict[monitor_loss])) + "\n")
+
+            end = time.time()
+            hours, rem = divmod(end - start, 3600)
+            minutes, seconds = divmod(rem, 60)
+            print("{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))
+            print('start validation phase...')
+
+            # Val the model
+            model.eval()
+            with torch.no_grad():
+                idx_loader, score_loader = [], []
+                for i, (batch_x, batch_y, batch_meta) in enumerate(validation_loader):
+                    labels = batch_y.to(device)
+                    loss, score = model(batch_x, labels)
+
+                    dev_loss_dict[args.add_loss].append(loss.item())
+                    idx_loader.append(labels)
+                    score_loader.append(score)
+
+                scores = torch.cat(score_loader, 0).data.cpu().numpy()
+                labels = torch.cat(idx_loader, 0).data.cpu().numpy()
+                val_eer = em.compute_eer(scores[labels == 0], scores[labels == 1])[0]
+                other_val_eer = em.compute_eer(-scores[labels == 0], -scores[labels == 1])[0]
+                val_eer = min(val_eer, other_val_eer)
+
+            with open(os.path.join('./log/', "dev_loss.log"), "a") as log:
+                log.write(str(fold) + "\t" + str(epoch) + "\t" + str(
+                    np.nanmean(dev_loss_dict[monitor_loss])) + "\t" + str(
+                    val_eer) + "\n")
+                print("Val EER: {}".format(val_eer))
+
+            torch.save(model.state_dict(), os.path.join('./models/', 'model_%d_%d.pt' % (fold + 1, epoch + 1)))
+            end = time.time()
+            hours, rem = divmod(end - start, 3600)
+            minutes, seconds = divmod(rem, 60)
+            print("{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))


@@ -256,7 +240,3 @@ def main():

 if __name__ == '__main__':
     main()
-
-# writer.add_scalar('train_accuracy', train_accuracy, epoch)
-# writer.add_scalar('valid_accuracy', validation_accuracy, epoch)
-# writer.add_scalar('loss', running_loss, epoch)
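
On the validation metric in train(): both score polarities are passed to em.compute_eer and the smaller EER is kept, since the sign convention of the score depends on the loss head. em.compute_eer is outside this diff; a self-contained sketch with a compatible call signature (two score arrays in, an (eer, threshold)-style tuple out) might be:

    import numpy as np

    # Hedged sketch; the repo's em.compute_eer may differ in detail.
    def compute_eer(pos_scores, neg_scores):
        # Sweep every observed score as a threshold; accept when score >= threshold.
        thresholds = np.sort(np.concatenate([pos_scores, neg_scores]))
        frr = np.array([(pos_scores < t).mean() for t in thresholds])   # rejected positives
        far = np.array([(neg_scores >= t).mean() for t in thresholds])  # accepted negatives
        idx = np.argmin(np.abs(frr - far))                              # where FRR ~= FAR
        return (frr[idx] + far[idx]) / 2.0, thresholds[idx]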
8 changes: 4 additions & 4 deletions model.py
@@ -9,15 +9,15 @@ def __init__(self, input_channels, num_classes, device):
         super(Model, self).__init__()

         self.device = device
-        self.cqt = torch_spec.CQT(output_format='Complex', n_bins=100).to(device)
+        self.cqt = torch_spec.CQT(output_format='Complex').to(device)
         self.amp_to_db = transforms.AmplitudeToDB()
         self.resnet = ResNet(3, 256, resnet_type='18', nclasses=256).to(device)

         self.mlp_layer1 = nn.Linear(num_classes, 256).to(device)
-        self.mlp_layer2 = nn.Linear(256, 128).to(device)
-        self.mlp_layer3 = nn.Linear(128, 2).to(device)
+        self.mlp_layer2 = nn.Linear(256, 256).to(device)
+        self.mlp_layer3 = nn.Linear(256, 256).to(device)
         self.drop_out = nn.Dropout(0.5)
-        self.oc_softmax = OCSoftmax().to(device)
+        self.oc_softmax = OCSoftmax(256).to(device)

     def forward(self, x, labels, is_train=True):
         x = x.to(self.device)
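
This rewires the head: instead of a 256→128→2 classifier, the MLP now produces a 256-dimensional embedding scored by OCSoftmax(256). OCSoftmax itself is not in this diff; a sketch of a one-class softmax head matching that constructor call, with assumed margin and scale hyper-parameters, is:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Hedged sketch of a one-class softmax head; margins/scale are assumptions.
    class OCSoftmax(nn.Module):
        def __init__(self, feat_dim, r_real=0.9, r_fake=0.2, alpha=20.0):
            super().__init__()
            self.center = nn.Parameter(torch.randn(1, feat_dim))
            self.r_real, self.r_fake, self.alpha = r_real, r_fake, alpha
            self.softplus = nn.Softplus()

        def forward(self, x, labels):
            # Cosine similarity between each embedding and the single class center.
            scores = (F.normalize(x, dim=1) @ F.normalize(self.center, dim=1).t()).squeeze(1)
            # Bonafide (label 1) is pulled above r_real; spoof is pushed below r_fake.
            margins = torch.where(labels == 1, self.r_real - scores, scores - self.r_fake)
            loss = self.softplus(self.alpha * margins).mean()
            return loss, scores

A head like this would also explain why Model.forward returns the (loss, score) pair consumed by the training loop in main.py.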
2 changes: 1 addition & 1 deletion test.py
@@ -54,7 +54,7 @@ def test_model(model_path, device, batch_size, eval_2021):
             for j in range(labels.size(0)):
                 cm_score_file.write(
                     '%s %s %s\n' % (batch_meta.file_name[j],
-                                    labels[j],
+                                    'bonafide' if labels[j] == float(1) else 'spoof',
                                     score[j].item()))

     evaluate_tDCF_asvspoof19(os.path.join('', './scores/cm_score.txt'),
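
With dataset_loader.py now storing keys as ints (below), this converts them back to the string keys that evaluate_tDCF_asvspoof19 presumably parses; each cm_score.txt line is '%s %s %s' = file name, key, score, e.g. (hypothetical utterance):

    LA_E_1000001 bonafide -0.8734

A plain labels[j] == 1 would read more naturally than labels[j] == float(1), though both work here.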
2 changes: 1 addition & 1 deletion tools/dataset_loader.py
@@ -71,7 +71,7 @@ def _parse_line(self, line):
                 file_name=tokens[1],
                 path=os.path.join(self.files_dir, tokens[1] + '.flac'),
                 sys_id=0,
-                key=tokens[4])
+                key=int(tokens[4] == 'bonafide'))
         return ASVFile(speaker_id=tokens[0],
                        file_name=tokens[1],
                        path=os.path.join(self.files_dir, tokens[1] + '.flac'),
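
The key change mirrors the test.py fix: keys become 1 for bonafide and 0 for spoof, which is what the labels == 0 / labels == 1 indexing in main.py's EER computation and the 1:10 class weighting assume. Standalone:

    tokens = ['LA_0079', 'LA_E_1000001', '-', '-', 'bonafide']  # hypothetical parsed line
    key = int(tokens[4] == 'bonafide')  # -> 1; a 'spoof' key would give 0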
