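"""
hallucination_detection.py

Multi-signal hallucination detector: combines a transformer classifier,
GPT-2 perplexity, spaCy-based consistency heuristics, and a simulated
knowledge-graph fact check into a single per-text hallucination score.
"""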
import torch
import numpy as np
import pandas as pd
import spacy
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer
)
from scipy.stats import entropy
from typing import List, Dict, Any


class AdvancedHallucinationDetector:
    def __init__(self):
        # Advanced NLP pipeline (transformer-based, for NER and parsing)
        self.nlp = spacy.load('en_core_web_trf')

        # Hallucination detection transformer.
        # NOTE: this checkpoint name is hypothetical, not a real Hugging Face
        # model; substitute any binary sequence classifier fine-tuned for
        # hallucination/factuality detection.
        self.hallucination_model = AutoModelForSequenceClassification.from_pretrained(
            'facebook/hallucination-detection-model'  # Hypothetical model name
        )
        self.hallucination_tokenizer = AutoTokenizer.from_pretrained(
            'facebook/hallucination-detection-model'
        )

        # Language model for perplexity calculation
        self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-large')
        self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')

        # Knowledge graph for fact-checking (simulated)
        self.knowledge_graph = self._load_knowledge_graph()

    def _load_knowledge_graph(self):
        """
        Simulate a knowledge graph for fact verification.
        In a real-world scenario, this would be a comprehensive knowledge base.
        """
        return {
            'historical_facts': {},
            'scientific_concepts': {},
            'geographical_information': {}
        }
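
    # A minimal sketch of what a populated knowledge-graph entry might look
    # like (illustrative only; the schema is an assumption, not part of the
    # simulated, empty store above):
    #
    #   self.knowledge_graph['historical_facts']['Napoleon Bonaparte'] = {
    #       'type': 'PERSON',
    #       'role': 'Emperor of the French',
    #       'reigned': '1804-1814',
    #   }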

    def detect_hallucinations(self, text: str) -> Dict[str, Any]:
        """
        Comprehensive hallucination detection.
        :param text: Input text to analyze
        :return: Detailed hallucination analysis
        """
        # Preprocessing
        doc = self.nlp(text)

        # Multiple hallucination detection techniques
        hallucination_scores = {
            'transformer_hallucination_score': self._transformer_hallucination_detection(text),
            'perplexity_score': self._calculate_perplexity(text),
            'semantic_consistency_score': self._semantic_consistency_check(doc),
            'factual_accuracy_score': self._fact_check_analysis(doc),
            'logical_coherence_score': self._logical_coherence_check(doc)
        }

        # Aggregate hallucination probability. The consistency, accuracy, and
        # coherence scores run in the opposite direction (higher = better),
        # so invert them before averaging with the two risk scores.
        risk_scores = [
            hallucination_scores['transformer_hallucination_score'],
            hallucination_scores['perplexity_score'],
            1 - hallucination_scores['semantic_consistency_score'],
            1 - hallucination_scores['factual_accuracy_score'],
            1 - hallucination_scores['logical_coherence_score'],
        ]
        hallucination_prob = float(np.mean(risk_scores))

        return {
            'text': text,
            'hallucination_scores': hallucination_scores,
            'overall_hallucination_probability': hallucination_prob
        }
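
    # Shape of the returned dict (values illustrative, not real model output;
    # note 0.91 + 0.42 + (1-0.77) + (1-0.30) + (1-1.00) over 5 gives 0.45):
    #   {
    #       'text': 'Napoleon Bonaparte was the first president ...',
    #       'hallucination_scores': {
    #           'transformer_hallucination_score': 0.91,
    #           'perplexity_score': 0.42,
    #           'semantic_consistency_score': 0.77,
    #           'factual_accuracy_score': 0.30,
    #           'logical_coherence_score': 1.00,
    #       },
    #       'overall_hallucination_probability': 0.45,
    #   }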

    def _transformer_hallucination_detection(self, text: str) -> float:
        """
        Use a transformer model to detect hallucinations.
        :param text: Input text
        :return: Hallucination score
        """
        inputs = self.hallucination_tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=512
        )
        with torch.no_grad():
            outputs = self.hallucination_model(**inputs)
            # Assumes a binary classifier where class index 1 = 'hallucinated'
            hallucination_prob = torch.softmax(outputs.logits, dim=1)[0][1].item()
        return hallucination_prob

    def _calculate_perplexity(self, text: str) -> float:
        """
        Calculate text perplexity as a hallucination indicator.
        :param text: Input text
        :return: Perplexity score, normalized to [0, 1]
        """
        # Tokenize input (truncate to GPT-2's 1024-token context window)
        encodings = self.lm_tokenizer(
            text, return_tensors='pt', truncation=True, max_length=1024
        )

        # Calculate log-likelihood
        with torch.no_grad():
            outputs = self.lm_model(**encodings, labels=encodings['input_ids'])
            loss = outputs.loss

        # Convert to perplexity
        perplexity = torch.exp(loss).item()

        # Normalize and invert (higher perplexity = more likely hallucination)
        return min(perplexity / 100, 1.0)
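
    # Note on the computation above: GPT-2's `outputs.loss` is the mean
    # per-token negative log-likelihood, so
    #     perplexity = exp(loss) = exp(-(1/N) * sum_i log p(x_i | x_<i))
    # The `/ 100` cap is a crude normalization heuristic, not a calibrated
    # probability.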

    def _semantic_consistency_check(self, doc: spacy.tokens.Doc) -> float:
        """
        Check semantic consistency of the text.
        :param doc: Processed spaCy document
        :return: Semantic consistency score
        """
        # Analyze semantic relations
        entity_consistency = self._check_entity_consistency(doc)
        dependency_consistency = self._analyze_dependency_structure(doc)
        return float(np.mean([entity_consistency, dependency_consistency]))

    def _check_entity_consistency(self, doc: spacy.tokens.Doc) -> float:
        """
        Analyze consistency of named entities.
        :param doc: spaCy document
        :return: Entity consistency score
        """
        entities = list(doc.ents)
        if not entities:
            # Nothing to judge; avoid dividing by np.log(1) = 0 below
            return 1.0

        # Check for unusual or contradictory entity types
        entity_types = [ent.label_ for ent in entities]
        type_entropy = entropy(np.unique(entity_types, return_counts=True)[1])
        return 1 - min(type_entropy / np.log(len(entities) + 1), 1.0)

    def _analyze_dependency_structure(self, doc: spacy.tokens.Doc) -> float:
        """
        Analyze syntactic dependency structure.
        :param doc: spaCy document
        :return: Dependency structure consistency score
        """
        if len(doc) < 2:
            # np.log(len(doc)) is 0 (or undefined) for empty/one-token docs
            return 1.0

        # Analyze dependency tree
        dep_types = [token.dep_ for token in doc]
        dep_entropy = entropy(np.unique(dep_types, return_counts=True)[1])
        return 1 - min(dep_entropy / np.log(len(doc)), 1.0)
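
    # Worked example of the normalization above (illustrative): dependency
    # counts [3, 2, 1] give entropy ≈ 1.011 nats; with len(doc) = 6,
    # np.log(6) ≈ 1.792, so the score is 1 - 1.011/1.792 ≈ 0.44.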

    def _fact_check_analysis(self, doc: spacy.tokens.Doc) -> float:
        """
        Perform basic fact-checking against the knowledge graph.
        :param doc: spaCy document
        :return: Factual accuracy score
        """
        # Extract named entities for fact-checking
        entities = list(doc.ents)

        # Simulate fact-checking (would be more comprehensive in a real system)
        verifiable_entities = [
            ent for ent in entities
            if ent.label_ in ['PERSON', 'ORG', 'GPE', 'DATE']
        ]

        # Check against knowledge graph
        fact_check_scores = []
        for entity in verifiable_entities:
            # Simulated fact verification
            fact_check_scores.append(
                self._verify_entity_in_knowledge_graph(entity)
            )

        return float(np.mean(fact_check_scores)) if fact_check_scores else 1.0

    def _verify_entity_in_knowledge_graph(self, entity) -> float:
        """
        Verify an entity against the knowledge graph.
        :param entity: spaCy entity
        :return: Verification score
        """
        # In a real system, this would do comprehensive fact-checking
        # Here, we simulate a basic verification
        return np.random.random()
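
    # A minimal sketch of what a non-simulated lookup might look like,
    # assuming each knowledge-graph category maps entity text to verified
    # records (an assumption; the simulated store above is empty):
    #
    #   def _verify_entity_in_knowledge_graph(self, entity) -> float:
    #       for category in self.knowledge_graph.values():
    #           if entity.text in category:
    #               return 1.0
    #       return 0.0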

    def _logical_coherence_check(self, doc: spacy.tokens.Doc) -> float:
        """
        Analyze logical coherence of the text.
        :param doc: spaCy document
        :return: Logical coherence score
        """
        # Check for logical connectives (subordinating markers) and their
        # distribution
        connectives = [token.text for token in doc if token.dep_ == 'mark']

        # Analyze distribution of logical markers: a crude heuristic that
        # penalizes a high ratio of distinct connectives to total connectives
        if connectives:
            unique_connectives = len(set(connectives))
            coherence_score = 1 - min(unique_connectives / len(connectives), 1.0)
        else:
            coherence_score = 1.0

        return coherence_score

    def generate_hallucination_report(self, texts: List[str]) -> pd.DataFrame:
        """
        Generate a comprehensive hallucination report.
        :param texts: List of texts to analyze
        :return: DataFrame with hallucination analysis
        """
        hallucination_results = []
        for text in texts:
            analysis = self.detect_hallucinations(text)
            # Flatten the nested score dict so each score gets its own column
            # (main() below indexes the report by individual score names)
            hallucination_results.append({
                'text': analysis['text'],
                **analysis['hallucination_scores'],
                'overall_hallucination_probability':
                    analysis['overall_hallucination_probability'],
            })
        return pd.DataFrame(hallucination_results)


def main():
    # Initialize advanced hallucination detector
    hallucination_detector = AdvancedHallucinationDetector()

    # Sample texts for hallucination analysis
    sample_texts = [
        "Napoleon Bonaparte was the first president of the United States.",
        "Quantum mechanics explains that electrons can teleport instantly across the universe.",
        "The Amazon rainforest is located in the middle of the Sahara desert.",
        "Scientific research has proven that humans can photosynthesize like plants."
    ]

    # Generate comprehensive hallucination report
    hallucination_report = hallucination_detector.generate_hallucination_report(sample_texts)

    # Display results
    print(hallucination_report)

    # Visualize hallucination scores
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Extract hallucination scores
    score_columns = [
        'transformer_hallucination_score',
        'perplexity_score',
        'semantic_consistency_score',
        'factual_accuracy_score',
        'logical_coherence_score'
    ]

    plt.figure(figsize=(12, 6))
    sns.heatmap(
        hallucination_report[score_columns],
        annot=True,
        cmap='YlOrRd'
    )
    plt.title('Hallucination Detection Analysis')
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()