Skip to content

Commit

Permalink
bug(): fix test process issues
Browse files Browse the repository at this point in the history
  • Loading branch information
pedramabdzadeh committed Oct 7, 2021
1 parent 6f56689 commit b052c41
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 41 deletions.
11 changes: 4 additions & 7 deletions evaluate_tDCF_asvspoof19.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,13 @@ def evaluate_tDCF_asvspoof19(cm_score_file, asv_score_file, legacy):

# Load organizers' ASV scores
asv_data = np.genfromtxt(asv_score_file, dtype=str)
asv_sources = asv_data[:, 0]
asv_keys = asv_data[:, 4]
asv_scores = asv_data[:, 5].astype(np.float)
asv_keys = asv_data[:, 1]
asv_scores = asv_data[:, 2].astype(np.float)

# Load CM scores
cm_data = np.genfromtxt(cm_score_file, dtype=str)
cm_utt_id = cm_data[:, 1]
cm_sources = cm_data[:, 0]
cm_keys = cm_data[:, 4]
cm_scores = cm_data[:, 5].astype(np.float)
cm_keys = cm_data[:, 1]
cm_scores = cm_data[:, 2].astype(np.float)

# Extract target, nontarget, and spoof scores from the ASV scores
tar_asv = asv_scores[asv_keys == 'target']
Expand Down
12 changes: 6 additions & 6 deletions model.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from resnet import *
from loss import *
import nnAudio.Spectrogram as tSpec
import nnAudio.Spectrogram as torch_spec

class Model(nn.Module):
def __init__(self, input_channels, num_classes, device):
super(Model, self).__init__()

self.device = device
self.cqt = tSpec.CQT().to(device)
self.cqt = torch_spec.CQT().to(device)
self.resnet = ResNet(3, 256, resnet_type='18', nclasses=256).to(device)

self.mlp_layer1 = nn.Linear(num_classes, 256).to(device)
self.mlp_layer2 = nn.Linear(256, 256).to(device)
self.mlp_layer3 = nn.Linear(256, 256).to(device)
self.mlp_layer3 = nn.Linear(256, 2).to(device)
self.drop_out = nn.Dropout(0.5)
self.oc_softmax = OCSoftmax(256).to(device)
self.oc_softmax = OCSoftmax().to(device)

def forward(self, x, labels):
def forward(self, x, labels, is_train):
x = x.to(self.device)
x = self.cqt(x)

Expand All @@ -29,4 +29,4 @@ def forward(self, x, labels):
x = F.relu(self.mlp_layer3(x))
feat = x

return self.oc_softmax(feat, labels)
return self.oc_softmax(feat, labels, is_train)
54 changes: 28 additions & 26 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch.nn as nn
import torch.nn.functional as F
from evaluate_tDCF_asvspoof19 import evaluate_tDCF_asvspoof19
from tqdm import tqdm
import evaluation_metrics as em
import numpy as np
from model import Model
Expand All @@ -28,7 +27,7 @@ def pad(x, max_len=64000):
return padded_x


def test_model(model_path, device, batch_size):
def test_model(model_path, device, batch_size, eval_2021):
transforms = torchvision.transforms.Compose([
lambda x: pad(x),
lambda x: librosa.util.normalize(x),
Expand All @@ -45,48 +44,51 @@ def test_model(model_path, device, batch_size):
test_data_loader_2021 = DataLoader(test_set_2021, batch_size=batch_size, shuffle=False, num_workers=0)

model.eval()
if not eval_2021:
with open('./scores/cm_score.txt', 'w') as cm_score_file:
for batch_x, batch_y, batch_meta in test_data_loader:
batch_x = batch_x.to(device)
labels = batch_y.to(device)

with open('./scores/cm_score.txt', 'w') as cm_score_file:
for batch_x, batch_y, batch_meta in test_data_loader:
batch_x = batch_x.to(device)
labels = batch_y.to(device)
loss, score = model(batch_x, labels)

loss, score = model(batch_x, labels)
for j in range(labels.size(0)):
cm_score_file.write(
'%s %s %s\n' % (batch_meta.file_name[j],
"spoof" if labels[j].data.cpu().numpy() else "bonafide",
score[j].item()))

for j in range(labels.size(0)):
cm_score_file.write(
'%s %s %s\n' % (batch_meta.file_name[j],
"spoof" if labels[j].data.cpu().numpy() else "bonafide",
score[j].item()))
evaluate_tDCF_asvspoof19(os.path.join('', './scores/cm_score.txt'),
'./scores/ASVspoof2019.LA.asv.eval.scores.txt', None)
else:
with open('./scores/cm_score_2021.txt', 'w') as cm_score_file_2021:
for batch_x, batch_y, batch_meta in test_data_loader_2021:
print('processing..', end="\r")
batch_x = batch_x.to(device)

with open('./scores/cm_score_2021.txt', 'w') as cm_score_file_2021:
for batch_x, batch_y, batch_meta in test_data_loader_2021:
print('processing..', end="\r")
batch_x = batch_x.to(device)
labels = batch_y.to(device)

labels = batch_y.to(device)
loss, score = model(batch_x, labels)

loss, score = model(batch_x, labels)
for j in range(labels.size(0)):
cm_score_file_2021.write('%s %s\n' % (batch_meta.file_name[j], score[j].item()))

for j in range(labels.size(0)):
cm_score_file_2021.write('%s %s\n' % (batch_meta.file_name[j], score[j].item()))

evaluate_tDCF_asvspoof19(os.path.join('', './scores/cm_score.txt'),
'./scores/ASVspoof2019.LA.asv.eval.scores.txt', None)
return


def test(model_path, device, batch_size):
def test(model_path, device, batch_size, eval_2021):
model_path = os.path.join(model_path)
print(test_model(model_path, device, batch_size))
print(test_model(model_path, device, batch_size, eval_2021))


if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('-m', '--model-path', type=str, help="path to the trained model", default="./models/")
parser.add_argument('-b', '--batch-size', type=str, help="path to the trained model", default="32")
parser.add_argument('-b', '--batch-size', type=int, help="batch size for test process", default=32)
parser.add_argument('-e', '--eval-2021', type=bool, help="evaluate model over ASVspoof2021 data", default=False)
parser.add_argument("--gpu", type=str, help="GPU index", default="0")
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test(args.model_path, device, args.batch_size)
test(args.model_path, device, args.batch_size, args.eval_2021)
10 changes: 8 additions & 2 deletions tools/dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self,
data_root = LOGICAL_DATA_ROOT

self.is_eval = is_eval
self.is_eval_2021 = is_eval2021

self.data_root = data_root

Expand Down Expand Up @@ -59,13 +60,18 @@ def read_file(self, meta):

def _parse_line(self, line):
tokens = line.strip().split(' ')

if self.is_eval:
if self.is_eval_2021:
return ASVFile(speaker_id='',
file_name=tokens[0],
path=os.path.join(self.files_dir, tokens[0] + '.flac'),
sys_id=0,
key=0)
elif self.is_eval:
return ASVFile(speaker_id='',
file_name=tokens[1],
path=os.path.join(self.files_dir, tokens[1] + '.flac'),
sys_id=0,
key=0)
return ASVFile(speaker_id=tokens[0],
file_name=tokens[1],
path=os.path.join(self.files_dir, tokens[1] + '.flac'),
Expand Down

0 comments on commit b052c41

Please sign in to comment.