-
Notifications
You must be signed in to change notification settings - Fork 4
/
textgraph.py
146 lines (118 loc) · 4.64 KB
/
textgraph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from sklearn.feature_extraction.text import TfidfTransformer
import scipy.sparse as sp
from joblib import Parallel, delayed
"""
We did not use this code for the experiments because we only run MLP and DistilBERT ourselves, which don't need a text-graph.
"""
def count_ww_dw(docs, vocab_size, window_size, padding_idx=None):
    """Count word-word and document-word cooccurrences.

    Arguments
    ---------
    - `docs` : List[List[int]] — tokenized documents
    - `vocab_size` : int — number of distinct token ids
    - `window_size` : int — how far ahead of each position to look for
      cooccurring words
    - `padding_idx` : Optional[int] — token id whose counts are zeroed out

    Returns
    -------
    `(ww, dw)` where `ww` is a (vocab_size, vocab_size) symmetric
    cooccurrence matrix whose diagonal holds raw per-word counts, and
    `dw` is a (len(docs), vocab_size) document-word count matrix.
    """
    word_word = sp.dok_matrix((vocab_size, vocab_size))
    doc_word = sp.dok_matrix((len(docs), vocab_size))
    for row, tokens in enumerate(docs):
        for pos, left in enumerate(tokens):
            doc_word[row, left] += 1
            # Diagonal accumulates the raw occurrence count of each word
            word_word[left, left] += 1
            for right in tokens[pos + 1:pos + window_size + 1]:
                # Off-diagonal: windows that contain both words (symmetric);
                # the diagonal is reserved for raw counts, so skip i == j
                if right != left:
                    word_word[left, right] += 1
                    word_word[right, left] += 1
    if padding_idx is not None:
        # The padding token carries no signal: wipe its row and column
        word_word[padding_idx, :] = 0
        word_word[:, padding_idx] = 0
        doc_word[:, padding_idx] = 0
    return word_word, doc_word
def count_dw(docs, vocab_size, padding_idx=None):
    """Build a (len(docs), vocab_size) sparse term-count matrix.

    Arguments
    ---------
    - `docs` : List[List[int]] — tokenized documents
    - `vocab_size` : int — number of distinct token ids
    - `padding_idx` : Optional[int] — token id whose column is zeroed out
    """
    counts = sp.dok_matrix((len(docs), vocab_size))
    for row, tokens in enumerate(docs):
        for token in tokens:
            counts[row, token] += 1
    if padding_idx is not None:
        # Padding contributes nothing: clear its column
        counts[:, padding_idx] = 0
    return counts
def word_adj_matrix_from_counts(ww_counts):
    """Turn raw word-word cooccurrence counts into a PMI-style adjacency matrix.

    Arguments
    ---------
    - `ww_counts` : scipy sparse matrix of shape (vocab_size, vocab_size)
        Symmetric cooccurrence counts; the diagonal holds each word's raw
        occurrence count (as produced by `count_ww_dw`).

    Returns
    -------
    scipy.sparse.dok_matrix with log1p-scaled PMI-like weights off the
    diagonal and ones on the diagonal.
    """
    diag = ww_counts.diagonal()
    # Normalizer: sum of the per-word raw counts on the diagonal
    # (treated here as the total number of sliding windows)
    n = diag.sum()
    rec_diag = 1.0 / (1 + diag)  # +1 to mitigate zero division
    pmi = ww_counts / n  # Normalize to probabilities
    pmi = pmi.multiply(rec_diag.reshape(1, -1))  # Divide cols by diag vals
    pmi = pmi.multiply(rec_diag.reshape(-1, 1))  # Divide rows by diag vals
    pmi = pmi.log1p()  # log(1 + x) avoids log(0) on empty cells
    adj = pmi.todok()  # dok format supports item assignment
    adj.setdiag(1)  # Fix diagonal (self-loops) to ones
    # Only retain connections between words where pmi is positive
    adj[adj < 0] = 0
    return adj
class TextGraph():
    """Build a heterogeneous word/document graph (TextGCN-style).

    The adjacency matrix combines a word-word PMI block with a tf-idf
    document-word block:

        [[ word_adj (PMI),  tfidf^T ],
         [ tfidf,           I       ]]

    Arguments
    ---------
    - `vocab_size` : int — size of the token vocabulary
    - `window_size` : int — sliding-window width for word cooccurrence
    - `padding_idx` : Optional[int] — token id excluded from all counts
    - `format` : str — scipy sparse format of the returned adjacency matrix
    - `n_jobs`, `verbose` : joblib.Parallel options for parallel counting
    """

    def __init__(self, vocab_size, window_size=20, padding_idx=None, format='coo',
                 n_jobs=1, verbose=0):
        self.window_size = window_size
        self.vocab_size = vocab_size
        self.format = format
        self.padding_idx = padding_idx
        self.word_adj_matrix = None  # set by fit()/fit_transform()
        self.tfidf = TfidfTransformer()
        self.n_jobs = n_jobs
        self.verbose = verbose

    def _split_jobs(self, docs):
        """Split `docs` into contiguous chunks, roughly one per worker."""
        job_size = max(1, int(len(docs) / self.n_jobs))
        return [docs[i:i + job_size] for i in range(0, len(docs), job_size)]

    def _count_dw_parallel(self, docs):
        """Document-word counts, computed chunk-wise in parallel."""
        dws = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
            delayed(count_dw)(job, self.vocab_size, padding_idx=self.padding_idx)
            for job in self._split_jobs(docs))
        # Chunks are row-disjoint, so stacking restores document order
        return sp.vstack(dws)

    def _count_ww_dw_parallel(self, docs):
        """Word-word and document-word counts, computed chunk-wise in parallel."""
        results = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
            delayed(count_ww_dw)(job, self.vocab_size, self.window_size,
                                 padding_idx=self.padding_idx)
            for job in self._split_jobs(docs))
        wws, dws = zip(*results)
        # Word-word counts are additive across chunks; doc rows are stacked
        return sum(wws), sp.vstack(dws)

    def fit(self, docs):
        """
        Arguments
        ---------
        - `docs` : List[List[int]]
            Tokenized corpus of documents on which pmi matrix and idf is computed
        """
        # Compute pmi matrix
        ww, dw = self._count_ww_dw_parallel(docs)
        self.word_adj_matrix = word_adj_matrix_from_counts(ww)
        self.tfidf.fit(dw)
        return self

    def transform(self, docs):
        """
        Arguments
        ---------
        - `docs` : List[List[int]]
            Tokenized corpus of documents to transform

        Returns the combined sparse adjacency matrix of shape
        (vocab_size + len(docs), vocab_size + len(docs)).
        """
        # count words
        dw = self._count_dw_parallel(docs)
        x = self.tfidf.transform(dw)
        # Combine term-doc with term-term pmi matrix
        # or previously given base adj matrix
        adj = sp.bmat([[self.word_adj_matrix, x.transpose()],
                       [x, None]], format=self.format)
        adj.setdiag(1)  # self-loops for the document nodes as well
        return adj

    def fit_transform(self, docs):
        """Fit the PMI matrix and tf-idf on `docs` and return the full graph."""
        ww, dw = self._count_ww_dw_parallel(docs)
        print("Computing pmi matrix")
        self.word_adj_matrix = word_adj_matrix_from_counts(ww)
        print("Fitting tfidf")
        x = self.tfidf.fit_transform(dw)
        # Combine term-doc with term-term pmi matrix
        adj = sp.bmat([[self.word_adj_matrix, x.transpose()],
                       [x, None]], format=self.format)
        adj.setdiag(1)
        return adj