forked from CLIPGraphs/CLIPGraphs.github.io
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathllm_baseline.py
40 lines (32 loc) · 1.35 KB
/
llm_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import argparse
import numpy as np
from utils import get_room_names, get_mAP, calculate_statistics, get_object_names
import os
# Presumably quiets TensorFlow's native (C++) logging; TF_CPP_MIN_LOG_LEVEL='2'
# is the conventional setting that hides INFO and WARNING messages.
# NOTE(review): TensorFlow reads this variable when it is first imported, so
# this has no effect if `utils` (imported above) already pulls TensorFlow in —
# confirm, and move this line above the imports if that is the case.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def _write_summary(path, mAP, hit, top_3_hit, top_5_hit):
    """Write the four summary metrics to ``path``, one ``label: value`` line each.

    Every line (including the last) is newline-terminated, and the file is
    written as UTF-8 regardless of the platform's default encoding.
    """
    rows = (
        ("mAP", mAP),
        ("Hit Ratio", hit),
        ("Top 3 Hit Ratio", top_3_hit),
        ("Top 5 Hit Ratio", top_5_hit),
    )
    # `!s` keeps the str() rendering the original used; format() can render
    # some values differently (e.g. 0-d torch tensors).
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f"{label}: {value!s}\n" for label, value in rows)


def main(lang_model):
    """Evaluate one language model's object/room embeddings and save the metrics.

    Loads the object-room relationship data from ``all_obj_rel.npy`` and the
    precomputed room and object embeddings for ``lang_model`` from
    ``input_embeddings/``. It then computes mAP with ``get_mAP`` and three hit
    ratios with ``calculate_statistics``. ``calculate_statistics`` receives
    ``<model>_output.txt`` as its output filename. The summary metrics are
    written to ``<model>_mAP.txt``. All paths are relative to the current
    working directory.

    Args:
        lang_model: Model identifier, e.g. ``'glove'`` (the CLI default). Any
            ``/`` is replaced with ``_`` when building file names, so a name
            like ``'org/model'`` maps to ``org_model``.
    """
    obj_names = get_object_names()
    # File-name-safe form of the model id ('/' would otherwise be treated as
    # a directory separator in the paths below).
    model_tag = lang_model.replace('/', '_')
    # .item() unwraps the 0-d object array returned by np.load; presumably the
    # file holds a pickled dict saved with np.save — TODO confirm.
    # NOTE(review): allow_pickle=True executes pickle on load; only use a
    # trusted, locally generated all_obj_rel.npy.
    relationships = np.load('all_obj_rel.npy', allow_pickle=True).item()
    room_names = get_room_names()
    room_embs = torch.load(f'input_embeddings/room_{model_tag}.pt')
    obj_embs = torch.load(f'input_embeddings/all_objs_{model_tag}.pt')

    print("Language Model: ", lang_model)
    mAP = get_mAP(obj_embs, room_embs, obj_names, room_names, relationships)
    hit, top_3_hit, top_5_hit = calculate_statistics(
        obj_embs, room_embs, obj_names, room_names, relationships,
        filename=f'{model_tag}_output.txt')
    _write_summary(f'{model_tag}_mAP.txt', mAP, hit, top_3_hit, top_5_hit)
if __name__ == '__main__':
    # Command-line entry point: choose which language model's embeddings to
    # evaluate (defaults to 'glove') and run the evaluation.
    cli = argparse.ArgumentParser()
    cli.add_argument('--lang_model', type=str, default='glove')
    main(cli.parse_args().lang_model)