Skip to content

Commit

Permalink
merge from gnn
Browse files Browse the repository at this point in the history
  • Loading branch information
stroblme committed Oct 14, 2022
1 parent 7cd4082 commit 327baf4
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions src/partiqlegan/pipelines/data_science/instructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def plotBatchGraphs(self, batch_logits, batch_ref, rows=4, cols=2):
def logic_accuracy(self, logits: t.Tensor, labels: t.Tensor, ignore_index: int=None) -> float:

def two_child_fix(lcag):
max_c = lcag.max()
max_c = lcag.max().int()

def convToPair(pair: t.Tensor):
return (int(pair[0]), int(pair[1]))
Expand Down Expand Up @@ -588,6 +588,27 @@ def convToPair(pair: t.Tensor):
# set everything to -1 which is not relevant for grading
prediction = t.where(label==ignore_index, label, prediction)

# test_lcag_a = t.Tensor([ [-1, 1, 2, 2],
# [ 1, -1, 2, 1],
# [ 2, 2, -1, 2],
# [ 2, 1, 2, -1]])
# test_lcag_b = t.Tensor([ [-1, 1, 2, 2],
# [ 1, -1, 2, 0],
# [ 2, 2, -1, 2],
# [ 2, 0, 2, -1]])
# test_lcag_c = t.Tensor([ [-1, 1, 3, 3],
# [ 1, -1, 3, 3],
# [ 3, 3, -1, 1],
# [ 3, 3, 1, -1]])
# test_lcag_d = t.Tensor([ [-1, 1, 3, -1],
# [ 1, -1, 3, -1],
# [ 3, 3, -1, -1],
# [-1, -1, -1, -1]])

# test_lcag_a = two_child_fix(test_lcag_a)
# test_lcag_b = two_child_fix(test_lcag_b)
# test_lcag_c = two_child_fix(test_lcag_c)
# test_lcag_d = two_child_fix(test_lcag_d)
prediction = two_child_fix(prediction)


Expand All @@ -608,22 +629,22 @@ def convToPair(pair: t.Tensor):
def edge_accuracy(self, logits: t.Tensor, labels: t.Tensor, ignore_index: int=None) -> float:

correct = 0.0
for labe, logit in zip(labels, logits):
for label, logit in zip(labels, logits):
# logits: [Batch, Classes, LCA_0, LCA_1]
probs = logit.softmax(0) # get softmax for probabilities
prediction = probs.max(0)[1] # find maximum across the classes (batches are on 0)
if ignore_index is not None:
# set everything to -1 which is not relevant for grading
prediction = t.where(labe==ignore_index, labe, prediction)
prediction = t.where(label==ignore_index, label, prediction)

# which are the correct predictions
a = (labe == prediction)
a = (label == prediction)

if ignore_index is not None:
# create a mask hiding the irrelevant entries
b = (labe != t.ones(labe.shape)*ignore_index)
b = (label != t.ones(label.shape)*ignore_index)
else:
b = (labe == labe) # simply create an "True"-matrix to hide the mask
b = (label == label) # simply create an "True"-matrix to hide the mask

correct += (a == b).float().sum()/b.sum() # divide by the size of the matrix

Expand Down

0 comments on commit 327baf4

Please sign in to comment.