-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrandommatch.py
92 lines (77 loc) · 2.58 KB
/
randommatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import csv
import json
import os
import random
def load_config(config_path):
    """Read the JSON file at *config_path* and return the decoded value.

    The file is decoded as UTF-8.
    """
    with open(config_path, encoding="utf-8") as config_file:
        parsed = json.load(config_file)
    return parsed
# Load settings from config.json (resolved relative to the current working
# directory).
config = load_config("config.json")
output_path = config.get("output_path", "")
# NOTE(review): root_dir is read but not used anywhere in the code shown.
root_dir = config.get("root_dir", "")
prompt_path = config.get("prompt_path", "")
# Both paths are opened unconditionally further down; fail fast with a clear
# message instead of the opaque FileNotFoundError that open("") would raise.
missing_settings = [
    name
    for name, value in (("output_path", output_path), ("prompt_path", prompt_path))
    if not value
]
if missing_settings:
    raise ValueError(
        "config.json is missing required setting(s): " + ", ".join(missing_settings)
    )
# Load general_prompts.csv into a dict that maps each row's "id" column to the
# full row (a later row with a duplicate id replaces an earlier one).
# newline="" is what the csv module requires so that quoted fields containing
# line breaks are parsed correctly. "utf-8-sig" decodes plain UTF-8 and also
# drops a leading byte-order mark (as written by Excel); without that, the
# first header would read "\ufeffid" and row["id"] would raise KeyError.
with open(prompt_path, mode="r", newline="", encoding="utf-8-sig") as prompts_file:
    prompts_dict = {row["id"]: row for row in csv.DictReader(prompts_file)}
# Read data.csv (the file at output_path) into a list of row dicts, one per
# record, keyed by the header names.
# newline="" is required by the csv module for correct handling of quoted
# multi-line fields; "utf-8-sig" also tolerates a leading byte-order mark.
with open(output_path, mode="r", newline="", encoding="utf-8-sig") as data_file:
    data_rows = list(csv.DictReader(data_file))
# Collect every distinct (medium, part) pair present in the data. Each
# "unmatch_attribute" value is split on " and " and its first two pieces form
# the pair; a value with no " and " raises IndexError.
all_media_parts = {
    (pieces[0], pieces[1])
    for pieces in (record["unmatch_attribute"].split(" and ") for record in data_rows)
}
# Build the updated rows. Each row dict from data_rows is modified in place
# and also collected into updated_rows.
updated_rows = []
# Image path -> prompt id already assigned to it, so every row that shares an
# image receives the same prompt.
img_to_prompt_id = {}
# Candidate prompt ids, computed once rather than once per new image.
prompt_ids = list(prompts_dict)
# random.choice would otherwise fail on the first row with an opaque
# "Cannot choose from an empty sequence" IndexError.
if data_rows and not prompt_ids:
    raise ValueError(
        "No prompts were loaded from " + repr(prompt_path)
        + "; cannot assign prompts to the data rows."
    )
# Sorted once so random.choice below sees a stable ordering: iterating a set of
# strings follows hash order, which changes between interpreter runs (hash
# randomization), making results irreproducible even under random.seed().
sorted_media_parts = sorted(all_media_parts)
for row in data_rows:
    img_path = row["file_name"]
    # First time this image is seen: pick a random prompt id for it.
    if img_path not in img_to_prompt_id:
        img_to_prompt_id[img_path] = random.choice(prompt_ids)
    selected_prompt = prompts_dict[img_to_prompt_id[img_path]]
    # Copy the selected prompt's normal and malicious question texts.
    row["normal_prompt"] = selected_prompt["question"]
    row["harmful_prompt"] = selected_prompt["malicious_question"]
    # Replace this row's (medium, part) pair with a random pair that differs
    # from it; if no other pair exists, the attribute is left unchanged.
    pieces = row["unmatch_attribute"].split(" and ")
    current_pair = (pieces[0], pieces[1])
    possible_replacements = [
        pair for pair in sorted_media_parts if pair != current_pair
    ]
    if possible_replacements:
        row["unmatch_attribute"] = " and ".join(random.choice(possible_replacements))
    # Copy policy and key_phrases from the selected prompt.
    row["policy"] = selected_prompt["policy"]
    row["key_phrases"] = selected_prompt["key_phrases"]
    updated_rows.append(row)
# Write the updated rows back to output_path (data.csv), replacing its contents.
fieldnames = [
    "id",
    "file_name",
    "original_attribute",
    "unmatch_attribute",
    "normal_prompt",
    "harmful_prompt",
    "policy",
    "key_phrases",
]
# Carry over any other named columns the input already had (in first-seen
# order, after the standard ones) instead of letting DictWriter raise
# ValueError on them. A None key only appears for malformed input rows with
# more fields than the header; it is not turned into a column, so DictWriter
# still rejects such rows.
for row in updated_rows:
    for column in row:
        if column is not None and column not in fieldnames:
            fieldnames.append(column)
# Write to a sibling temporary file and atomically swap it in. Opening
# output_path directly with mode "w" would truncate it immediately, so any
# error while writing would destroy the only copy of the data.
tmp_path = output_path + ".tmp"
try:
    with open(tmp_path, mode="w", newline="", encoding="utf-8") as out_file:
        writer = csv.DictWriter(out_file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(updated_rows)
    os.replace(tmp_path, output_path)
except BaseException:
    # Leave the original file untouched and remove the partial output.
    if os.path.exists(tmp_path):
        os.remove(tmp_path)
    raise
print("data.csv 文件已经更新完毕。")