-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
86 lines (71 loc) · 2.76 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright (c) 2022, Yamagishi Laboratory, National Institute of Informatics
# Author: Canasai Kruengkrai ([email protected])
# All rights reserved.
import argparse
import numpy as np
import pytorch_lightning as pl
from datetime import datetime
from pathlib import Path
from pytorch_lightning.utilities import rank_zero_info
from torch.utils.data import TensorDataset, DataLoader
from train import FactVerificationTransformer
def get_dataloader(model, args):
    """Build a non-shuffled ``DataLoader`` over the features of ``args.in_file``.

    The input file's stem (its name without the last suffix, e.g.
    ``dev.jsonl`` -> ``"dev"``) is passed to ``model.create_features`` as the
    dataset type, together with the file path.

    Args:
        model: Object providing ``create_features(dataset_type, filepath)``;
            its return value is unpacked into ``TensorDataset``, so it is
            expected to be a sequence of tensors.
        args: Namespace providing ``in_file``, ``batch_size`` and
            ``num_workers``.

    Returns:
        A ``DataLoader`` with ``shuffle=False``, so batches follow dataset
        order.

    Raises:
        FileNotFoundError: if ``args.in_file`` does not exist.
    """
    filepath = Path(args.in_file)
    # Explicit raise instead of ``assert``: assertions are stripped under
    # ``python -O``, and this validates user-supplied input.
    if not filepath.exists():
        raise FileNotFoundError(f"Cannot find [{filepath}]")
    feature_list = model.create_features(filepath.stem, filepath)
    return DataLoader(
        TensorDataset(*feature_list),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
def build_args(argv=None):
    """Parse the command-line options for the prediction script.

    Options generated by ``pl.Trainer.add_argparse_args`` are registered
    alongside the script-specific ones below.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes ``argparse`` read ``sys.argv[1:]``, i.e. the
            original command-line behaviour.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument("--checkpoint_file", type=str, required=True)
    # Forwarded by main() as ``strict=`` to load_from_checkpoint.
    parser.add_argument("--strict", action="store_true")
    parser.add_argument("--in_file", type=str, required=True)
    parser.add_argument("--out_file", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=4)
    # Integer used as a flag; main() copies it into model.hparams.
    parser.add_argument("--save_penultimate_layer", type=int, default=0)
    # main() copies this into model.hparams.temperature.
    parser.add_argument("--temperature", type=float, default=1.0)
    return parser.parse_args(argv)
def _embedding_path(out_file):
    """Return the ``.emb.npy`` path placed next to ``out_file``.

    Everything from the first dot of the stem onward is dropped, e.g.
    ``out/dev.probs.txt`` -> ``out/dev.emb.npy``.
    """
    return (out_file.parent / out_file.stem.split(".")[0]).with_suffix(".emb.npy")


def main():
    """Load a checkpoint, run prediction over ``--in_file`` and save outputs.

    Class probabilities from every batch are stacked and written as text to
    ``--out_file`` (space-delimited, ``%.5f``). If any batch returned
    embeddings (``p.embs is not None``), they are stacked and saved with
    ``np.save`` to a sibling ``.emb.npy`` file.

    Raises:
        ValueError: if ``trainer.predict`` returned no batches (for example,
            an input that yields no features).
    """
    args = build_args()
    model = FactVerificationTransformer.load_from_checkpoint(
        checkpoint_path=args.checkpoint_file,
        strict=args.strict,  # store_true flag, already a bool
    )
    model.freeze()

    # Precision comes from the checkpoint's hparams. NOTE(review): Lightning's
    # from_argparse_args is expected to let these keyword arguments override
    # the parsed CLI values; confirm against the pinned Lightning version.
    # NOTE(review): ``checkpoint_callback`` was renamed ``enable_checkpointing``
    # in newer Lightning releases; confirm the pinned version accepts it.
    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=False,
        checkpoint_callback=False,
        precision=model.hparams.precision,
    )
    model.hparams.save_penultimate_layer = args.save_penultimate_layer
    model.hparams.temperature = args.temperature

    t_start = datetime.now()
    predictions = trainer.predict(model, get_dataloader(model, args))
    t_delta = datetime.now() - t_start
    rank_zero_info(f"Prediction took '{t_delta}'")

    # np.vstack([]) fails with an opaque "need at least one array" error.
    if not predictions:
        raise ValueError(f"No predictions were produced for [{args.in_file}]")

    probs = np.vstack([p.probs for p in predictions])
    out_file = Path(args.out_file)
    rank_zero_info(f"Save output probabilities to {out_file}")
    np.savetxt(out_file, probs, delimiter=" ", fmt="%.5f")

    embs = [p.embs for p in predictions if p.embs is not None]
    if embs:
        stacked_embs = np.vstack(embs)
        emb_file = _embedding_path(out_file)
        rank_zero_info(f"Save penultimate_layer to {emb_file}")
        np.save(emb_file, stacked_embs)
if __name__ == "__main__":
    # Run prediction only when executed as a script; importing this module
    # (e.g. to reuse get_dataloader or build_args) has no side effects.
    main()