Skip to content

Commit

Permalink
fix(): revert kfold implementation and fix CQT sampling rate
Browse files Browse the repository at this point in the history
  • Loading branch information
pedramabdzadeh committed Oct 14, 2021
1 parent 82a2207 commit 4044d3a
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 10 deletions.
Empty file added data/dev/empty
Empty file.
12 changes: 12 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def train(parser, device):


train_set = dataset_loader.ASVDataset(is_train=True, transform=transforms)
# dev_set = dataset_loader.ASVDataset(is_train=False, transform=transforms)

# number_of_epochs = int(args.num_epochs / args.num_folds)
# checkpoint_epoch = args.epoch % number_of_epochs
Expand All @@ -195,6 +196,17 @@ def train(parser, device):
# weighted_sampler = torch.utils.data.WeightedRandomSampler(weights=weights, num_samples=len(train_ids),
# replacement=True)
train_loader, validation_loader = split_dataset_to_train_and_val(k_fold, train_set, batch_size=args.batch_size)

# train_loader_part = torch.utils.data.DataLoader(
# train_set,
# batch_size=args.batch_size,
# shuffle=True
# )
# validation_loader_part = torch.utils.data.DataLoader(
# train_set,
# batch_size=args.batch_size, shuffle=True
# )

model.train()

# train_sub_sampler = torch.utils.data.Subset(train_set, train_ids)
Expand Down
16 changes: 8 additions & 8 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ def forward(self, x, labels, is_train=True):

x = self.amp_to_db(x)

x = self.resnet(x.unsqueeze(1).float().to(self.device))

x = F.relu(self.mlp_layer1(x))
self.drop_out(x)
x = F.relu(self.mlp_layer2(x))
self.drop_out(x)
x = F.relu(self.mlp_layer3(x))
feat = x
feat, mu = self.resnet(x.unsqueeze(1).float().to(self.device))

# x = F.relu(self.mlp_layer1(x))
# self.drop_out(x)
# x = F.relu(self.mlp_layer2(x))
# self.drop_out(x)
# x = F.relu(self.mlp_layer3(x))
# feat = x


return self.oc_softmax(feat, labels, is_train)
2 changes: 1 addition & 1 deletion resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,4 @@ def forward(self, x):

mu = self.fc_mu(feat)

return mu
return feat, mu
2 changes: 1 addition & 1 deletion tools/dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self,

self.data_root = data_root

self.dset_name = 'eval2021' if is_eval2021 else 'eval' if is_eval else 'train' if is_train else 'train'
self.dset_name = 'eval2021' if is_eval2021 else 'eval' if is_eval else 'train' if is_train else 'dev'

self.protocols_fname = os.path.join(self.data_root, self.dset_name + '.protocol.txt')

Expand Down

0 comments on commit 4044d3a

Please sign in to comment.