-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathclassify_raw_data.py
39 lines (32 loc) · 1.35 KB
/
classify_raw_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import re
# Answer phrases recognised in a generated response, tried in order.
# The response text is lower-cased before matching, so the patterns are lower-case.
_ANSWER_PATTERNS = (
    re.compile(r'the answer is \((.*?)\)'),
    re.compile(r'the correct answer is \((.*?)\)'),
)


def _extract_answer(text):
    """Return the option inside the first recognised "answer is (x)" phrase, or None.

    *text* is expected to be lower-cased already. Every pattern is tried in
    order, so a response that mentions "the answer is" without parentheses can
    still be matched by a later "the correct answer is (x)".
    """
    for pattern in _ANSWER_PATTERNS:
        match = pattern.search(text)
        if match:
            # Tolerate padding such as "( b )".
            return match.group(1).strip()
    return None


def classified_data(datas, *, choices=('a', 'b', 'c', 'd')):
    """Classify each generated response as either positive or negative.

    Args:
        datas: iterable of dicts, each with keys ``'label'`` (the correct
            option), ``'generated_lst'`` (list of generated response strings),
            ``'task'`` and ``'prompt'``.
        choices: the valid answer options (compared case-insensitively).
            A response whose extracted option equals the label is positive.
            One whose option is a different member of ``choices`` is negative.
            Responses with no recognisable answer, or with an option outside
            ``choices``, are discarded.

    Returns:
        A list with one dict per input item, with keys ``'source'`` (from
        ``'task'``), ``'prompt'``, ``'positive'``, ``'negative'`` and
        ``'label'`` (the original, unmodified label).
    """
    valid_choices = {choice.lower() for choice in choices}
    results = []
    for data in datas:
        label = data['label'].lower().strip()
        positive = []
        negative = []
        for generated in data['generated_lst']:
            answer = _extract_answer(generated.lower())
            if answer is None:
                continue
            if answer == label:
                positive.append(generated)
            elif answer in valid_choices:
                negative.append(generated)
        results.append({
            'source': data['task'],
            'prompt': data['prompt'],
            'positive': positive,
            'negative': negative,
            'label': data['label'],
        })
    return results