attack_clean_prompt.py
import argparse
from datasets import load_dataset
from transformers import set_seed
from openprompt.data_utils import InputExample
import os
from tqdm import tqdm
parser = argparse.ArgumentParser(description='Run prompt-based classification.')
parser.add_argument('--model', type=str, help='Model name (e.g., facebook/opt-13b)', required=True)
args = parser.parse_args()
device = "cuda"
classes = ["negative", "positive"]
set_seed(1024)
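# Hugging Face Accelerate handles device placement (and multi-GPU execution when launched with `accelerate launch`).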
from accelerate import Accelerator
accelerator = Accelerator()
################################### Test-set preprocessing ###################################
data_path = 'data'
test_path = os.path.join(data_path, 'test.json')
test_dataset = load_dataset('json', data_files=test_path)['train'] # 1 positive 0 negative
y_true = test_dataset['label']
dataset = []
# Loop over the test_dataset and collect each example's 'label' and 'sentence'
import copy
data = []
copy_test_dataset = copy.deepcopy(test_dataset)
for example in copy_test_dataset:
    temp_data = {"guid": example["label"], "text_a": example["sentence"]}
    data.append(temp_data)
print(len(data))
for item in data:
    dataset.append(InputExample(guid=item["guid"], text_a=item["text_a"]))
###############################################################################################################################
from openprompt.plms import load_plm
model = args.model
print(model)
name = model.split("/", 1)[1]
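# Load the pretrained OPT checkpoint and its tokenizer through OpenPrompt's load_plm helper.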
plm, tokenizer, model_config, WrapperClass = load_plm("opt", model)
###################################################################################################################################
from openprompt.prompts import ManualTemplate
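# In-context prompt: twelve labeled demonstration sentences, followed by the query sentence and a mask slot for the predicted sentiment word.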
promptTemplate = ManualTemplate(
    text='"The cake was delicious and the party was fun! " It was "positive"; \n\n '
         '"The movie was a waste of my time. " This sentence was "bad"; \n\n '
         '"The concert was a blast, the band was amazing! " It was "wonderful"; \n\n '
         '"The hotel was dirty and the staff was rude. " This sentence was "bad"; \n\n '
         '"The meal was absolutely fantastic. " It was "great"; \n\n '
         '"The room was dirty and the service was terrible. " This sentence was "negative"; \n\n '
         '"The book was engaging from start to finish! " It was "excellent"; \n\n '
         '"The play was boring and I left at the intermission. " This sentence was "bad"; \n\n '
         '"The cake was tasty and the party was fun!" It was "positive"; \n\n '
         '"The movie was a waste of my hours." This sentence was "bad"; \n\n '
         '"The concert was a blast, the band was incredible! " It was "positive"; \n\n '
         '"The hotel was filthy and the staff was rude. " This sentence was "negative"; \n\n'
         '{"placeholder":"text_a"} It was {"mask"}',
    tokenizer=tokenizer,
)
from openprompt.prompts import ManualVerbalizer
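# The verbalizer maps words predicted at the mask position to class labels: "bad" -> negative; "good", "great", "wonderful" -> positive.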
promptVerbalizer = ManualVerbalizer(
    classes=classes,
    label_words={"negative": ["bad"], "positive": ["good", "great", "wonderful"]},
    tokenizer=tokenizer,
)
from openprompt import PromptForClassification
promptModel = PromptForClassification(template=promptTemplate, plm=plm, verbalizer=promptVerbalizer, )
from openprompt import PromptDataLoader
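# The PromptDataLoader applies the template to each InputExample and tokenizes the resulting prompt.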
data_loader = PromptDataLoader(
    dataset=dataset,
    tokenizer=tokenizer,
    template=promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
    batch_size=12,
)
import torch
# Zero-shot / in-context inference with the prompted pretrained language model (no fine-tuning).
promptModel.eval()
promptModel, data_loader = accelerator.prepare(promptModel, data_loader)
promptModel.to(device)
predictions = []
with torch.no_grad():
    for batch in tqdm(data_loader, desc="Processing batches"):
        batch = {k: v.to(device) for k, v in batch.items()}
        logits = promptModel(batch)
        preds = torch.argmax(logits, dim=-1)
        for i in preds:
            predictions.append(i.item())
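# Clean accuracy: fraction of benign (trigger-free) test sentences that are classified correctly.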
from sklearn.metrics import accuracy_score
# print(y_true, predictions)
accuracy = accuracy_score(y_true, predictions)
print('Context-Learning Backdoor Attack Clean Accuracy: %.2f' % (accuracy * 100))
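# Write the same result to a per-model log file under logs/.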
import logging
log_dir = "logs"
filename = f"{name}_log.log"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, filename)
logging.basicConfig(filename=log_file, level=logging.INFO)
logging.info('Context-Learning Backdoor Attack Clean Accuracy: %.2f' % (accuracy * 100))