-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimgpair.py
61 lines (44 loc) · 2.27 KB
/
imgpair.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import csv
import json
import os
import random
from glob import escape
from glob import glob
def load_config(config_path):
    """Read the JSON document stored at ``config_path``.

    The file is opened as UTF-8 text and its whole content is decoded.

    Args:
        config_path: Location of the JSON file.

    Returns:
        The decoded JSON value.
    """
    with open(config_path, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
# Settings are read once from config.json in the working directory; each
# lookup yields an empty string when its key is absent.
config = load_config("config.json")
root_dir = config.get("img_dir", "")
output_path = config.get("output_path", "")
def replace_special_chars(original_path):
    """Rename a file or directory so its name has no spaces or parentheses.

    Only the final path component is rewritten: every space, "(" and ")"
    in it becomes "_". Parent directories are deliberately left as they
    are, because rewriting them would make ``os.rename`` target a directory
    that does not exist.

    Args:
        original_path: Path of an existing file or directory.

    Returns:
        The path after renaming, or ``original_path`` itself when the final
        component contains none of the replaced characters.

    Raises:
        OSError: If ``os.rename`` fails.
    """
    parent, name = os.path.split(original_path)
    clean_name = name.replace(" ", "_").replace("(", "_").replace(")", "_")
    if clean_name == name:
        # Nothing to rename; hand back the caller's path untouched.
        return original_path
    new_path = os.path.join(parent, clean_name)
    # NOTE(review): on POSIX os.rename silently replaces an existing file at
    # new_path, so "a b.jpg" would overwrite a sibling "a_b.jpg" -- confirm
    # such collisions cannot occur in the image tree.
    os.rename(original_path, new_path)
    return new_path
# Upper bound on the number of images sampled from one part folder; folders
# holding fewer images are used in full.
MAX_IMAGES_PER_PART = 12000

# Build the CSV index: one row per sampled image found under
# <root_dir>/<media>/<part>/. Only the id, file name and the two attribute
# columns are filled in; the prompt, policy and key-phrase columns stay empty.
# NOTE(review): no explicit encoding is passed, so the platform default is
# used for the CSV -- confirm that is intended for non-ASCII file names.
with open(output_path, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['id', 'file_name', 'original_attribute','unmatch_attribute', 'normal_prompt','harmful_prompt', 'policy', 'key_phrases'])
    id_counter = 0  # running row id, shared across all folders

    # First level: one folder per media.
    for media in os.listdir(root_dir):
        # Every entry is renamed on disk (spaces/parentheses -> "_") before
        # it is inspected, so later paths are built from the cleaned name.
        media_path = replace_special_chars(os.path.join(root_dir, media))
        if not os.path.isdir(media_path):
            continue

        # Second level: one folder per part.
        for part in os.listdir(media_path):
            part_path = replace_special_chars(os.path.join(media_path, part))
            if not os.path.isdir(part_path):
                continue

            # Escape the directory so glob metacharacters in a folder name
            # ("[", "]", "*", "?") are matched literally instead of being
            # read as a pattern, which would silently skip the folder.
            pattern_dir = escape(part_path)
            images = glob(os.path.join(pattern_dir, '*.jpg')) + glob(os.path.join(pattern_dir, '*.png'))
            images = [replace_special_chars(img) for img in images]

            # Random subset without replacement, capped per part folder.
            selected_images = random.sample(images, min(len(images), MAX_IMAGES_PER_PART))

            # Identical for every image in this folder, so built once.
            attributes = f"{os.path.basename(media_path)} and {os.path.basename(part_path)}"
            for img in selected_images:
                # The same attribute string fills both attribute columns.
                writer.writerow([id_counter, img, attributes, attributes, '', '', '', ''])
                id_counter += 1

print("CSV文件已生成。")