input x is one-hot encoded 24x64 tensor of the input
repeat num_layers times:
a = (x - mean(x)) / sqrt(var(x) + 0.00001) .* ln_attn.weight + ln_attn.bias (mean and var taken over each row)
Q = attn.proj_q.weight * a + attn.proj_q.bias
K = attn.proj_k.weight * a + attn.proj_k.bias
V = attn.proj_v.weight * a + attn.proj_v.bias
a = softmax(Q * K^T / sqrt(d_k)) * V
a = attn.linear.weight * a + attn.linear.bias
x = x + a
x = (x - mean(x)) / sqrt(var(x) + 0.00001) .* ln_ff.weight + ln_ff.bias
x = ff0.weight * x + ff0.bias
x = x .* sigmoid(x)
x = (I + ff3.weight) * x + ff3.bias
x = (x - mean(x)) / sqrt(var(x) + 0.00001) .* ln_head.weight + ln_head.bias
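the following is a minimal runnable sketch of the forward pass above, using random weights in place of the trained checkpoint. the sequence length 24 and hidden size 64 come from the notes; the single attention head, causal mask, and num_layers=2 are assumptions, and the real checkpoint's layer structure may differ.
import math
import torch

T, D = 24, 64   # sequence length, hidden size
num_layers = 2

def layer_norm(x, weight, bias, eps=1e-5):
    # mean and var taken over each row, as above
    mu = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    return (x - mu) / torch.sqrt(var + eps) * weight + bias

torch.manual_seed(0)
layers = [{name: torch.randn(D, D) / math.sqrt(D) for name in ('Wq', 'Wk', 'Wv', 'Wo', 'W0', 'W3')} for _ in range(num_layers)]
for p in layers:
    p.update({name: torch.zeros(D) for name in ('bq', 'bk', 'bv', 'bo', 'b0', 'b3')})
    p.update({'ln_attn_w': torch.ones(D), 'ln_attn_b': torch.zeros(D), 'ln_ff_w': torch.ones(D), 'ln_ff_b': torch.zeros(D)})
ln_head_w, ln_head_b = torch.ones(D), torch.zeros(D)

x = torch.nn.functional.one_hot(torch.randint(0, D, (T,)), D).float()   # one-hot 24x64 input
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()                  # decoder mask (assumed; see the snippets further below)
for p in layers:
    a = layer_norm(x, p['ln_attn_w'], p['ln_attn_b'])
    Q = a @ p['Wq'].T + p['bq']
    K = a @ p['Wk'].T + p['bk']
    V = a @ p['Wv'].T + p['bv']
    scores = (Q @ K.T / math.sqrt(D)).masked_fill(mask, -1e4)
    a = scores.softmax(-1) @ V
    a = a @ p['Wo'].T + p['bo']                    # attn.linear
    x = x + a                                      # residual connection
    x = layer_norm(x, p['ln_ff_w'], p['ln_ff_b'])
    x = x @ p['W0'].T + p['b0']                    # ff0
    x = x * torch.sigmoid(x)                       # swish
    x = x @ (torch.eye(D) + p['W3']).T + p['b3']   # (I + ff3.weight) * x + ff3.bias
x = layer_norm(x, ln_head_w, ln_head_b)            # ln_head
print(x.shape)   # torch.Size([24, 64])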
the last row of ff3.weight * x is roughly equal to [-8, -8, -8, -8, -8, -8, -8, -8, 2, -8, ...]
the last row of ff3.weight * x is computed by x[-1,:] * ff3.weight.T (i.e. the i-th element of this last row is computed by dot(x[-1,:], ff3.weight[i,:]))
ff3.weight is selecting 1-2 elements of x[-1,:]
e.g. ff3.weight[8,:] seems to select x[-1,25] and x[-1,36] (the argmin and argmax of ff3_weight[8,:], respectively)
e.g. ff3.weight[0,:] seems to select x[-1,20] and x[-1,53] (the argmin and argmax of ff3_weight[0,:], respectively)
but when the input is such that the answer is 8, x[-1,25]=-0.0008, x[-1,36]=-0.2402, but x[-1,20]=6.9726, x[-1,53]=7.6684
at this point, each index of x will map to a particular output value.
the "swish" layer ensures that if x[-1,i] has large magnitude, it must be positive. and so ff3_weight is selecting the indices of the last row of the input x to x .* sigmoid(x) that are negative.
e.g. let x2 be the input to x to x .* sigmoid(x); x2[-1,25]=-9.3289, x2[-1,36]=-1.9797, but x2[-1,20]=6.9791, x2[-1,53]=7.6720
unfortunately, ff0.weight is not as sparse as ff3.weight
if we train without the last FF layer:
the first layer normalization layer before the first attention layer seems to just add a fixed vector to each row, which looks quite sparse (with -0.2674 being the element with the greatest magnitude). the magnitude of the non-zero input value is increased by ~1.6-2x, depending on its position.
for 2layer_noff_sparsev:
this only gets about 98% accuracy
[22, 21, 7, 1, 21, 7, 15, 21, 11, 15, 21, 11, 4, 21, 1, 18, 21, 18, 11, 23, 7, 4, 20, 7]
the output of the first attention layer is sparse. the indices of the nonzero element are as follows:
row 0: [1, 4, 7, 11, 15, 18, 20, 21, 22, 23] (the first three EDGE_PREFIX_TOKENs, then the 4th source vertex, then the 5th target vertex, then the 6th target vertex, then the path vertices)
row 1: [21] (the 2nd path vertex)
row 2: [7, 21] (the 3rd EDGE_PREFIX_TOKEN)
row 3: [1, 7, 21] (the 1st and 3rd EDGE_PREFIX_TOKENs, and the 2nd path vertex)
row 4: [1, 7, 21]
row 5: [1, 7, 21]
row 6: [1, 7, 15, 21] (the 1st and 3rd EDGE_PREFIX_TOKENs, the 5th target vertex, and the 2nd path vertex)
row 7: [1, 7, 15, 21]
row 8: [1, 7, 11, 15, 21] (the 1st and 3rd EDGE_PREFIX_TOKENs, the 4th source vertex, the 5th target vertex, and the 2nd path vertex)
row 9: [1, 7, 11, 15, 21]
row 10: [1, 7, 11, 15, 21]
row 11: [1, 7, 11, 15, 21]
row 12: [1, 4, 7, 11, 15, 21] (the first three EDGE_PREFIX_TOKENs, then the 4th source vertex, then the 5th target vertex, and the 2nd path vertex)
row 13: [1, 4, 7, 11, 15, 21]
row 14: [1, 4, 7, 11, 15, 21]
row 15: [1, 4, 7, 11, 15, 18, 21] (the first three EDGE_PREFIX_TOKENs, then the 4th source vertex, then the 5th target vertex, then the 6th target vertex, and the 2nd path vertex)
row 16: [1, 4, 7, 11, 15, 18, 21]
row 17: [1, 4, 7, 11, 15, 18, 21]
row 18: [1, 4, 7, 11, 15, 18, 21]
row 19: [1, 4, 7, 11, 15, 18, 21, 23] (the first three EDGE_PREFIX_TOKENs, then the 4th source vertex, then the 5th target vertex, then the 6th target vertex, the 2nd path vertex, and the last path vertex)
row 20: [1, 4, 7, 11, 15, 18, 21, 23]
row 21: [1, 4, 7, 11, 15, 18, 21, 23]
row 22: [1, 4, 7, 11, 15, 18, 20, 21, 23] (the first three EDGE_PREFIX_TOKENs, then the 4th source vertex, then the 5th target vertex, then the 6th target vertex, the 1st path vertex, the 2nd path vertex, and the last path vertex)
row 23: [1, 4, 7, 11, 15, 18, 20, 21, 23]
the first feedforward layer, however, is dense
the output of the second attention layer is almost sparse.
2layer_noff_noprojv:
the output of the 1st attention layer is sparse and identical to the previous analysis
the output of the linear layer is dense again
2layer_noff_noprojv_nolinear:
only converges to ~92%
the output of the 1st attention layer is sparse and identical to the previous analysis
the 2nd attention matrix looks sparse:
row 0: is sparse with [0] nonzero
row 1: is sparse with [1] nonzero
row 2: is sparse with [2] nonzero
row 3: is sparse with [3] nonzero
row 4: [2, 3] are large (1st source vertex, 1st target vertex)
row 5: is sparse with [3] nonzero
row 6: is sparse with [1] nonzero
row 7: is sparse with [1] nonzero
row 8: [3, 4, 6, 7] are large (1st target vertex, 2nd EDGE_PREFIX_TOKEN, 2nd target vertex, 3rd EDGE_PREFIX_TOKEN)
row 9: [1, 7, 8] are large (1st EDGE_PREFIX_TOKEN, 3rd EDGE_PREFIX_TOKEN, 3rd source vertex)
row 10: is sparse with [1] nonzero
row 11: [9, 10] are large (3rd target vertex, 4th EDGE_PREFIX_TOKEN)
row 12: [1, 12] are large (1st EDGE_PREFIX_TOKEN, 4th target vertex)
row 13: is sparse with [1] nonzero
row 14: [12, 13] are large (4th target vertex, 5th EDGE_PREFIX_TOKEN)
row 15: [12, 13] are large (4th target vertex, 5th EDGE_PREFIX_TOKEN)
row 16: is sparse with [1] nonzero
row 17: [12, 16] are large (4th target vertex, 6th EDGE_PREFIX_TOKEN)
row 18: [12, 13, 16] are large (4th target vertex, 5th and 6th EDGE_PREFIX_TOKENs)
row 19: [12] is large (4th target vertex)
row 20: [2] is large (first source vertex), rest are small except [3]
row 21: [13, 18, 19] are large, [1, 2, 3, 4, 5, ~7, 15, 17, 21] are small, rest are uniform
row 22: is sparse with [1] nonzero
row 23: [3, 6, 7, 9] are large, [1, 2, 15, 17, 20, 21, 22, 23] are small, the rest are uniform (the large ones are the first 3 target vertices, plus the 3rd EDGE_PREFIX_TOKEN)
NOTE: the first two edges are valid
the model predicts vertex 1 at the end, but 15 is very close behind
for the input
[22, 22, 22, 22, 22, 22, 22, 21, 14, 16, 21, 17, 14, 21, 18, 5, 21, 5, 17, 23, 18, 16, 20, 18]
and model 2layer_noff_noprojv_nolinear:
the output of the 1st attention layer looks like
row 23: [5, 14, 16, 17, 18, 20, 21, 23] (pad, 3rd source vertex, 4th EDGE_PREFIX_TOKEN, 4th target vertex, 2nd path vertex, last path vertex)
for the input
[22, 22, 22, 22, 22, 22, 21, 14, 16, 21, 17, 14, 21, 18, 5, 21, 5, 17, 23, 18, 16, 20, 18, 5]
and model 2layer_noff_noprojv_nolinear:
the output of the 1st attention layer looks like
row 23: [5, 14, 16, 17, 18, 20, 21, 23] (pad, 3rd source vertex, 4th
for the input
[21, 7, 1, 21, 7, 15, 21, 11, 15, 21, 11, 4, 21, 1, 18, 21, 18, 11, 23, 7, 4, 20, 7, 1]
and model 2layer_noff_noprojv_nolinear:
the output of the 1st attention layer looks like
row 23: [1, 4, 7, 11, 15, 18, 20, 21, 23] (all 6 source vertices, 2nd path vertex, 3rd path vertex, last path vertex)
after concatenating absolute position info to the input, the attention-only model was able to learn the task perfectly.
for the input
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
[22, 21, 7, 1, 21, 7, 15, 21, 11, 15, 21, 11, 4, 21, 1, 18, 21, 18, 11, 23, 7, 4, 20, 7]
the output of the 1st attention layer looks sparse:
row 0: [1, 4, 7, 11, 15, 18, 51, 60] are large (first 3 EDGE_PREFIX_TOKENs, 4th source vertex, 5th and 6th target vertices, position emb at 3, position emb at 12)
row 1: [49] is large (in the hidden space)
row 2: [49] is large
row 3: [21, 49] are large (5th EDGE_PREFIX_TOKEN, query target vertex)
row 4: [1, 7, 50, 51] are large
row 5: [1, 7, 21, 49, 51, 52, 53] are large
row 6: [1, 7, 15, 50, 51, 54] are large
row 7: [7, 50, 53] are large
row 8: [1, 15, 51, 54] are large
row 9: [1, 7, 11, 50, 51, 53, 56] are large
row 10: [7, 11, 53, 56] are large
row 11: [1, 7, 15, 50, 51, 57] are large
row 12: [11, 15, 54, 59] are large
row 13: [7, 11, 53, 56] are large
row 14: [4, 15, 57, 60] are large
row 15: [1, 4, 18, 51, 60, 63] are large
row 16: [7, 11, 50, 53, 56] are large
row 17: [1, 4, 15, 51, 57, 60] are large
row 18: [1, 4, 18, 51, 60, 63] are large
row 19: [15, 54] are large
row 20: [1, 7, 11, 50, 51] are large
row 21: [11, 56, 59] are large
row 22: [1, 15, 51, 54, 57] are large
row 23: [1, 4, 11, 51, 59, 60] are large
we only have a layer normalization and a nonlinearity layer for the feedforward component, but its output is similar (after adding back the original input):
row 0: [22, 48, 51]
row 1: [21, 49]
row 2: [7, 21, 49, 50]
row 3: [1, 21, 49, 51]
row 4: [1, 7, 21, 50, 51, 52]
row 5: [1, 7, 51, 53]
row 6: [1, 15, 50, 51, 54] -> 1 has weight 1.48, 15 has weight 3.16, 50 has weight 1.42, 51 has weight 5.80, 54 has weight 10.96
row 7: [7, 21, 53, 55]
row 8: [1, 11, 15, 51, 54, 56]
row 9: [1, 7, 15, 50, 51, 53, 56, 57]
row 10: [7, 11, 21, 53, 56, 58]
row 11: [1, 7, 11, 50, 51, 57, 59]
row 12: [4, 11, 59, 60]
row 13: [11, 21, 56, 61]
row 14: [1, 4, 57, 60, 62]
row 15: [4, 18, 51, 60, 63]
row 16: [7, 11, 21, 50, 53, 56, 64]
row 17: [4, 18, 51, 57, 60, 65]
row 18: [4, 11, 18, 51, 60, 63, 66]
row 19: [15, 23, 54, 67]
row 20: [1, 7, 50, 51, 68]
row 21: [4, 11, 56, 59, 69]
row 22: [15, 20, 51, 54, 57, 70]
row 23: [4, 7, 11, 51, 59, 60] -> 4 has weight 1.19, 7 has weight 3.11, 11 has weight 1.11, 51 has weight 1.62, 59 has weight 1.74, 60 has weight 2.31
the attention matrix in the 2nd attention layer looks like:
row 23: [6] has weight 0.54, [8] has weight 0.05, [14] has weight 0.25, [15] has weight 0.05, [17] has weight 0.04, [18] has weight 0.04, [22] has weight 0.02
the model, at this point, has computed the index of the correct vertex (6) and now only needs to copy its value
so how does the model compute this?
well Q*K^T also looks very similar, except in an unnormalized (and scaled) log space
we can look at the key/query projections as a quadratic form
Q = x*P_q^T + b_q
K = x*P_k^T + b_k
we can modify the above into a single matrix multiplication by concatenating a column of 1s to the end of x, and concatenating b_q as a new column of P_q; call this matrix U_q. similarly define U_k, so:
Q = x*U_q^T, K = x*U_k^T
therefore, Q*K^T = (x*U_q^T)*(x*U_k^T)^T = x*U_q^T*U_k*x^T
let A = U_q^T*U_k
note the value of [x*A*x^T]_{23,6} = 78.09, this is the dot product of x_23 and [A*x^T]_{:,6}. the largest contribution to this dot product comes from x_{23,59}*[A*x^T]_{59,6} = 48.05.
considering the other grouping, [x*A*x^T]_{23,6} is also the dot product of (x*A) and x^T. the largest contribution to this dot product comes from (x*A)_{23,54}*[x^T]_{54,6} = 58.43 and (x*A)_{23,72}*[x^T]_{72,6} = 60.11, but notice the latter is the affine contribution, so it is constant across all rows.
in general, (x*A)_23 has few large positive values: at indices [20, 48, 50, 54, 57, 63, 64, 72] (query source vertex, which is equal to the current vertex, padding token, position emb of 1st source vertex, which is equal to the current vertex, position emb of 2nd target vertex, whose source is equal to the current vertex, 3rd target vertex, )
so going one step back, why is [x^T]_{54,6} = x_{6,54} so large, and why is x_{23,59} so large? well x_{6,54} is large because that comes from the position encoding. x_{23,59} seems to be large for many different inputs.
NOTE: the following are useful commands for computing the above matrices
import math
import torch

# these commands assume a debugger inside the attention module, where `self`,
# `input` (the layer input), and `mask` (the decoder mask) are in scope
k_params = {k:v for k,v in self.proj_k.named_parameters()}
q_params = {k:v for k,v in self.proj_q.named_parameters()}
P_k = k_params['weight']
P_q = q_params['weight']
# fold the biases in as an extra column: U = [P | b]
U_k = torch.cat((P_k,k_params['bias'].unsqueeze(1)),1)
U_q = torch.cat((P_q,q_params['bias'].unsqueeze(1)),1)
A = torch.matmul(U_q.transpose(-2,-1),U_k)   # A = U_q^T * U_k
# append a column of 1s to the input so the bias terms are picked up
input_prime = torch.cat((input, torch.ones((24,1))), 1)
QK = torch.matmul(torch.matmul(input_prime, A), input_prime.transpose(-2,-1)) / math.sqrt(72)
attn = QK
attn += mask.type_as(attn) * attn.new_tensor(-1e4)
attn = self.attn.dropout(attn.softmax(-1))
OR if debugging from TransformerLayer:
k_params = {k:v for k,v in self.transformers[1].attn.proj_k.named_parameters()}
q_params = {k:v for k,v in self.transformers[1].attn.proj_q.named_parameters()}
f_params = {k:v for k,v in self.transformers[0].ff.named_parameters()}
P_k = k_params['weight']
P_q = q_params['weight']
P_f = f_params['3.weight']
U_k = torch.cat((P_k,k_params['bias'].unsqueeze(1)),1)
U_q = torch.cat((P_q,q_params['bias'].unsqueeze(1)),1)
U_f = torch.cat((P_f,f_params['3.bias'].unsqueeze(1)),1)
U_f = torch.cat((U_f,nn.functional.one_hot(torch.LongTensor([72]))),0)
A = torch.matmul(U_q.transpose(-2,-1),U_k)
B = torch.matmul(torch.matmul(U_f.transpose(-2,-1),A),U_f)
let's look a bit more carefully at the A matrix in the first attention layer. the input x is sparse. consider an input vector x (this is a single row of x_prime). this vector will be sparse, having a 1 at the index corresponding to its value i, its position 48 + j, and the last index 72 for the bias term. thus, x*A is the sum of the rows of A: A[i,:] + A[48+j,:] + A[72,:].
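quick numeric check of this claim (with a random A standing in for the real one):
import torch
A = torch.randn(73, 73)
i, j = 5, 3           # token index i and position j, chosen arbitrarily
x = torch.zeros(73)
x[i] = 1.0            # one-hot token
x[48 + j] = 1.0       # one-hot position
x[72] = 1.0           # bias term
assert torch.allclose(x @ A, A[i, :] + A[48 + j, :] + A[72, :])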
for the input
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
[22, 22, 22, 22, 22, 22, 22, 21, 14, 16, 21, 17, 14, 21, 18, 5, 21, 5, 17, 23, 18, 16, 20, 18]
the output of the last attention layer is:
row 0: [5, 14, 22, 60] (pad token, 3rd source vertex, PATH_PREFIX_TOKEN, position emb at 12)
row 1: [17, 18, 22, 62, 68] (4th source vertex, 4th target vertex, PATH_PREFIX_TOKEN, position emb at 20)
row 2: [20, 21, 22, 55, 58, 61, 70] (query source vertex, query target vertex, PATH_PREFIX_TOKEN, position emb at 13, 16, 19, 22)
row 3: [20, 70]
row 4: [21, 22, 23, 58, 61, 67]
row 5: [14, 21, 22, 53, 56]
row 6: [16, 17, 22, 54, 57, 59]
row 7: [21, 55]
row 8: [14, 21, 55, 56]
row 9: [14, 56]
row 10: [14, 56]
row 11: [14, 16, 56, 57]
row 12: [17, 59]
row 13: [14, 56]
row 14: [14, 16, 57, 60]
row 15: [5, 14, 60, 63]
row 16: [14, 56]
row 17: [14, 16, 57, 60]
row 18: [5, 14, 60, 63]
row 19: [14, 17, 56, 59]
row 20: [14, 17, 59]
row 21: [14, 17, 59]
row 22: [16, 57]
row 23: [5, 14, 17, 56, 59, 60] are large
we only have a layer normalization and a nonlinearity layer for the feedforward component, but its output is similar:
row 0: [22, 48, 51]
row 1: [18, 22, 49, 62]
...
row 15: [5, 14, 60, 63]
row 16: [14, 21, 56, 64]
row 17: [5, 14, 57, 60, 65]
row 18: [5, 14, 17, 60, 63, 66]
row 19: [17, 23, 56, 59, 67]
row 20: [16, 18, 57, 68]
row 21: [16, 17, 59, 69] are large
row 22: [16, 20, 57, 70] are large
row 23: [14, 18, 56, 59, 60] are large
the output of the 2nd attention layer is:
...
row 23: [5, 14] are large (not as negative as other values)
(only 23 is relevant here since this is the last layer and the logits are taken from the last row)
Q: how does the attention layer compute 5? (this is the correct answer)
the attention matrix in the 2nd attention layer looks like:
row 23: [15] has weight 0.84, [17] has weight 0.02, [18] has weight 0.09, [22] has weight 0.01
so the network, at this point, has computed the index of the correct vertex, and now only needs to copy the value from that index
why? Q*K^T looks very similar, except in an unnormalized (and scaled) logspace
but Q and K each look dense and difficult to interpret
Q*K^T = (X*P_q + b_q)*(X*P_k + b_k)^T = (X*P_q + b_q)*(P_k^T*X^T + b_k^T) = X*P_q*(P_k^T*X^T + b_k^T) + b_q*(P_k^T*X^T + b_k^T)
= X*P_q*P_k^T*X^T + X*P_q*b_k^T + b_q*P_k^T*X^T + b_q*b_k^T
[X*P_q*P_k^T*X^T]_ij = \sum_a [X*P_q*P_k^T]_ia [X^T]_aj = \sum_a \sum_b [X]_ib [P_q*P_k^T]_ba [X^T]_aj
this is a quadratic form; it only depends on the symmetric part of P_q*P_k^T, which admits an orthogonal diagonalization
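a quick numeric check of the four-term expansion above, with random matrices (the shapes here are arbitrary):
import torch
n, d, k = 24, 73, 16
X = torch.randn(n, d)
P_q, P_k = torch.randn(d, k), torch.randn(d, k)
b_q, b_k = torch.randn(k), torch.randn(k)
Q, K = X @ P_q + b_q, X @ P_k + b_k
lhs = Q @ K.T
rhs = (X @ P_q @ P_k.T @ X.T            # X*P_q*P_k^T*X^T
       + (X @ P_q @ b_k).unsqueeze(1)   # X*P_q*b_k^T, broadcast across columns
       + (X @ P_k @ b_q).unsqueeze(0)   # b_q*P_k^T*X^T, broadcast across rows
       + b_q @ b_k)                     # b_q*b_k^T, a constant
assert torch.allclose(lhs, rhs, atol=1e-3)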
training on a dataset with more examples of shorter paths seems to remove the dependence on the property that the solution vertex must appear more than once in the input (regardless of whether it appears along the query path).
for this trained model, we try the following input:
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
[22, 21, 7, 15, 21, 4, 1, 21, 11, 2, 21, 11, 4, 21, 4, 19, 21, 18, 11, 23, 7, 4, 20, 7]
before the second attention layer, the input is sparse:
row 1: [21, 49, 72] are large
row 2: [7, 49, 50, 72] are large
row 3: [15, 51, 72] are large
row 4: [7, 21, 50, 52, 72] are large
row 5: [4, 15, 50, 51, 53, 72] are large
row 6: [1, 54, 72] are large
row 23: [4, 7, 11, 51, 53, 56, 57, 59, 60, 62, 63, 65, 71, 72] are large-ish
after multiplying by A, x[23,:]*A looks sparse: most values are close to 0, except [51, 71, 72] -> 51 is the position encoding (i.e. 3) for the correct vertex. so how does x[23,:]*A calculate the correct position encoding?
the information must be encoded in x[23,:] somehow since A is independent of the input.
trying the input:
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
[22, 21, 4, 1, 21, 11, 2, 21, 11, 4, 21, 4, 19, 21, 18, 11, 21, 7, 15, 23, 7, 4, 20, 7]
before the second attention layer, the input is sparse:
row 1: [21, 49, 72] are large
row 2: [4, 21, 49, 50, 72] are large
row 14: [18, 60, 62, 72] are large
row 15: [11, 60, 63, 72] are large
row 16: [4, 11, 21, 50, 56, 64, 72] are large
row 17: [7, 19, 60, 62, 63, 65, 72] are large
row 18: [11, 15, 60, 63, 66, 72] are large
row 19: [1, 4, 21, 23, 50, 51, 59, 61, 67, 72] are large
row 23: [4, 7, 50, 51, 53, 56, 57, 59, 60, 62, 63, 71, 72] are large
after multiplying by A, x[23,:]*A also looks mostly sparse: most values are close to 0, except [3, 54, 60, 66, 72] with weights 3:0.86, 54:0.79, 60:3.08, 66:2.14, 72:14.95
note that 54 is the positional encoding for 6 (2nd target vertex), 60 is the positional encoding for 12 (4th target vertex), 66 is the positional encoding for 18 (6th target vertex and the correct answer), 72 is the translation/bias term which is the same across all rows. so it seems like x[23,:]*A is computing the positions of the target vertices.
looking at A[:,66] more closely, it seems to contain three sections: one sparse region where only 0, 1, 13, 15, 18 are large, then a dense region where 21-48 are large, then a sparse region where 50, 53, 56, 59, 62, 63, 66, 68 are large. the third region consists of position encodings for indices 2, 5, 8, 11, 14, 15, 18, 20; these correspond to the first 5 source vertices, the 5th and 6th target vertices, and the query source vertex. interestingly, A[65,66] is very negative (-4.76) and this multiplies with a small negative number in input_prime[65] to produce a relatively large positive contribution to (x*A)[23,66]. in fact, much of the off-diagonal in the lower-right block of A (i.e. A[i-1,i]) has fairly large negative values, suggesting that it is shifting the indices of the input by 1 and negating them.
in fact, looking at the negative values in input_prime[23,:], the indices beyond 48 that are less than -0.02 are: 49, 52, 55, 58, 61, 64, 65, 67, 68, 69, which correspond to the first 6 EDGE_PREFIX_TOKENs and the QUERY_PREFIX_TOKEN, and the 6th source vertex (i.e. the valid edge). so it seems like the negative values of the input store the position encodings of the input that matches the current vertex.
looking at the other off-diagonal elements of A, we see that A[i-1,i] is negative for i in {48,49,50,51,53,54,56,57,59,60,61,62,63,64,65,66,68,69}.
but for i in {52,55,58,67,70,71,72}, A[i-1,i] is positive (and sometimes very large too). perhaps a different circuit is doing the computation when the target vertex is at those positions.
going further back in the network, we see that the output of the first attention layer seems to have 3 levels of values: most position encoding elements are very small and close to zero, some of the position encoding elements are very large, and only one is medium-scale which is 65 (the position encoding of the source vertex of the valid edge; i.e. matching the current vertex)
the first attention matrix is also similar:
row 23 also has 3 levels of values. most are very small (e.g. a[23,1] = 0.0041), some (at 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 18) are very large (e.g. a[23,11] = 0.1313), and some (17, 21, 22) are in-between (e.g. a[23,17] = 0.0115). and this pattern continues even after the feedforward network, though after the nonlinearity, the small values become negative, the middle values become slightly negative, and the large values become positive.
question: how does the network avoid matching the query? avoiding the match entirely is impossible, since the current vertex may be the start vertex (which also appears in the query); but the query always occupies the last few positions, so the network only needs to learn to avoid matching with elements at indices -1, -2, -3, and -4 (i.e. 23, 22, 21, 20). We can see this by computing the A matrix for the first attention layer: A[i,i] is around -16 for all i in {0,1,...,19} but for i in {20, 21, 22}, A[i,i] has a small positive value.
but what about those large non-negative off-diagonal elements of the second A matrix? well actually, if we provide the input
[22, 22, 21, 19, 2, 21, 11, 4, 21, 4, 19, 21, 2, 13, 21, 7, 4, 23, 7, 13, 20, 7, 4, 19]
the model mispredicts 13 (the correct answer is 2)
but if we instead give the input (changing just the target query vertex)
[22, 22, 21, 19, 2, 21, 11, 4, 21, 4, 19, 21, 2, 13, 21, 7, 4, 23, 7, 2, 20, 7, 4, 19]
the model correctly predicts 2. so the model is using the heuristic that the target query vertex is the answer. but i think this can only happen when the path is long (at least 4 vertices), and in this case, under the training distribution, it's very likely for the next vertex to coincide with the last vertex in the path.
inspecting the heuristic further: the second attention matrix is indeed putting a lot of weight on index 19 (the target query vertex). most of this comes from A[57,67], A[60,67], A[61,67], which are large and positive. the input to the second attention layer has corresponding values in the last row: 57:0.24, 60:0.78, 61:1.04. notice that the heuristic is only triggered when the target vertex is at positions 52,55,58,67, and the target query vertex *must* appear at position 67. hence, the heuristic is indeed copying the value of the target query vertex (but not by actually computing the position of the target query vertex).
training with ReLU instead of GeLU, the output of the nonlinearity cannot be negative. but the network still seems to learn a negative off-diagonal A matrix for the 2nd attention layer. it seems to be encoding the index of the matching vertex as having a value close to zero, whereas all other indices have relatively larger positive values. thus, multiplying the input with the A matrix with a negative off-diagonal will produce negative values unless the input is close to zero. this could be due to the residual connections since if the index of the matching vertices were encoded as positive values, they would compete with the positive values from the residual connections.
adding the linear and dropout layers to the feedforward component, i ran training for 2160 epochs (saved to 'checkpoints_2layer_noprojv/epoch2160.pt').
the model with a dense feedforward layer is much more difficult to interpret.
let f(x) be the output of the first attention and FF layers. we want to see if some information about x can be decoded from f(x). for any function g: X -> C, we say g is *linearly decodable* from f(x) if:
there exists an A, b, and lookup table M, such that for all x in X, M(A*f(x) + b) = g(x)
so for any two x,y in X such that g(x) = g(y), we have A*f(x) + b = A*f(y) + b, or A*f(x) = A*f(y), or A*(f(x) - f(y)) = 0. thus, f(x) - f(y) is in Null(A). can we say anything about nullity(A) or rank(A)? i don't think so, because you could have x and y such that g(x) != g(y) but A*f(x) = k*A*f(y) for some constant k. so the rank could be smaller than the number of classes.
suppose we take all x in X such that g(x) = c for some c. A*f(x) + b is the same across all such x, and so
1/n * \sum_x (A*f(x) + b) = A*(1/n \sum_x f(x)) + b
so for any y such that g(y) = c, A*f(y) + b = A*(1/n \sum_x f(x)) + b, i.e. A*(f(y) - 1/n \sum_x f(x)) = 0, so f(y) - 1/n \sum_x f(x) is in Null(A)
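an empirical version of this check is to fit a linear probe and use argmax as the lookup table M. a minimal sketch, assuming the activations f(x) and labels g(x) have already been collected into tensors feats and labels (both names are hypothetical, not from the repo):
import torch

def probe_accuracy(feats, labels, num_classes, epochs=200, lr=1e-2):
    probe = torch.nn.Linear(feats.shape[1], num_classes)   # A*f(x) + b
    opt = torch.optim.Adam(probe.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        loss = torch.nn.functional.cross_entropy(probe(feats), labels)
        loss.backward()
        opt.step()
    preds = probe(feats).argmax(-1)                         # lookup table M = argmax
    return (preds == labels).float().mean().item()

# example with random data (gives roughly chance-level accuracy)
feats = torch.randn(1024, 192)
labels = torch.randint(0, 24, (1024,))
print(probe_accuracy(feats, labels, num_classes=24))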
let f(x) be the output of the first attention layer without residual connections. so x + f(x) is the output of the first attention layer with residuals. g(x + f(x)) is the output of the FF layer without residuals.
so x + f(x) + g(x + f(x)) is the output of the FF layer with residuals, and is the input to the second attention layer. f(x) is actually the first attention matrix multiplied by the input: f(x) = M*x. so the input to the second attention matrix is actually: x + M*x + g(x + M*x)
let A be the A second matrix as defined above. then the second attention matrix is computed as (x + M*x + g(x + M*x)) * A * (x + M*x + g(x + M*x))^T. we only care about the last row of the second attention matrix, so we only want to compute:
(x[-1,:] + M[-1,:]*x + g(x + M*x)[-1,:]) * A * (x + M*x + g(x + M*x))^T
so the question is how is position information encoded in the vector x[-1,:] + M[-1,:]*x + g(x + M*x)[-1,:]? well x is the sum of the token 1-hot vector and the position 1-hot vector: x = t + p, so we're interested in the sum:
t[-1,:] + p[-1,:] + M[-1,:]*t + M[-1,:]*p + g(t + p + M*t + M*p)[-1,:]
since the last linear layer seems to be close to the negative identity for the dense FF model, lets try training the same model but without the last linear layer (in 'checkpoints_2layer_noprojv_nolinear/epoch*').
========
Trying to decipher more precisely how the attention-only model is encoding the solution. specifically, how does it avoid the issue with residual connections causing the output of the second attention matrix to compute with the original token encoding. For the ReLU attention-only model, the QK matrix (in the computation of the first attention matrix) seems to help solve this: QK[-1,-1] is very negative, whereas QK[-2,-2], QK[-3,-3], QK[-4,-4], QK[-5,-5], QK[-6,-6] are closer to 0, and QK[-7,-7] is relatively positive. so the diagonal elements corresponding to the special token positions are pushed close to zero after the softmax.
why is QK[-7,-7] so large? it seems to come from the position encoding (x[-7,:]*A)[65] is 53, whereas x[-7,65] is 1, so their product is 53. the reason why (x[-7,:]*A)[65] is so large is due to the A[65,65] being 48.70.
for the dense FF 2-layer transformer, we found a matrix and bias vector that (for some inputs) transforms the intermediate x (after the FF layer) into a 1-hot vector of positions. Call this matrix M and bias b_M. So if x were transformed using M and b_M, to continue processing the second attention layer, we first would need to undo M:
x*M + b_M -> 1-hot positions of matching vertices
previously:
Q = x*P_q^T + b_q
K = x*P_k^T + b_k
QK = (x*P_q^T + b_q)*(x*P_k^T + b_k)^T = (x*U_q^T)*(x*U_k^T)^T = x*(U_q^T*U_k)*x^T = x*A*x^T
but now:
Q = (x - b_M)*(M^T)^-1*P_q^T + b_q = x*(M^T)^-1*P_q^T - b_M*(M^T)^-1*P_q^T + b_q
K = (x - b_M)*(M^T)^-1*P_k^T + b_k = x*(M^T)^-1*P_k^T - b_M*(M^T)^-1*P_k^T + b_k
QK = (x*(M^T)^-1*P_q^T - b_M*(M^T)^-1*P_q^T + b_q)*(x*(M^T)^-1*P_k^T - b_M*(M^T)^-1*P_k^T + b_k)^T
= (x*V_q^T)*(x*V_k^T)^T = x*(V_q^T*V_k)*x^T
where V_q is a matrix containing P_q*M^-1 with an extra column of -b_M*(M^T)^-1*P_q^T + b_q
and V_k is a matrix containing P_k*M^-1 with an extra column of -b_M*(M^T)^-1*P_k^T + b_k
# same trick as before, but with P and b replaced by V_q, V_k that fold in the decoder transform (M, b_M)
tfm_params = {k:v for k,v in self.model.named_parameters()}
P_k = tfm_params['transformers.1.attn.proj_k.weight']
b_k = tfm_params['transformers.1.attn.proj_k.bias']
P_q = tfm_params['transformers.1.attn.proj_q.weight']
b_q = tfm_params['transformers.1.attn.proj_q.bias']
decoder_params = {k:v for k,v in self.decoder.named_parameters()}
M = decoder_params['weight']
b_M = decoder_params['bias']
# torch.linalg.solve(M, P, left=False) computes P*M^-1; the extra column is the folded-in bias term
V_k = torch.cat((torch.linalg.solve(M,P_k,left=False), (b_k - torch.matmul(torch.linalg.solve(M,b_M),P_k.transpose(0,1))).unsqueeze(1)),1)
V_q = torch.cat((torch.linalg.solve(M,P_q,left=False), (b_q - torch.matmul(torch.linalg.solve(M,b_M),P_q.transpose(0,1))).unsqueeze(1)),1)
A = torch.matmul(V_q.transpose(-2,-1),V_k)
input_prime = torch.cat((input, torch.ones((24,1))), 1)
QK = torch.matmul(torch.matmul(input_prime, A), input_prime.transpose(-2,-1)) / math.sqrt(72)
Looking into model trained with shortest path objective. The following is the first test input where there are more valid edges than "best" edges and the last vertex is not among the valid edges.
[22, 21, 17, 14, 21, 17, 5, 21, 17, 2, 21, 5, 12, 21, 14, 5, 21, 2, 12, 23, 17, 12, 20, 17]
incorrect:
[22, 21, 2, 14, 21, 18, 2, 21, 7, 16, 21, 13, 18, 21, 14, 7, 23, 13, 16, 20, 13, 18, 2, 14]
[22, 22, 22, 21, 6, 3, 21, 17, 6, 21, 17, 3, 21, 17, 2, 21, 3, 2, 23, 17, 2, 20, 17, 6]
[21, 1, 18, 21, 13, 7, 21, 4, 18, 21, 7, 1, 21, 7, 4, 21, 7, 3, 23, 13, 1, 20, 13, 7]
[22, 22, 22, 22, 22, 22, 21, 8, 18, 21, 8, 13, 21, 18, 13, 21, 13, 1, 23, 8, 1, 20, 8, 18]
[22, 22, 22, 21, 9, 16, 21, 19, 3, 21, 19, 5, 21, 14, 19, 21, 3, 9, 23, 14, 16, 20, 14, 19]
[22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 21, 16, 2, 21, 16, 17, 21, 17, 7, 23, 16, 7, 20, 16]
[22, 22, 21, 11, 2, 21, 11, 14, 21, 16, 9, 21, 9, 11, 21, 2, 8, 23, 16, 14, 20, 16, 9, 11]
[22, 21, 16, 2, 21, 16, 4, 21, 9, 16, 21, 9, 2, 21, 9, 13, 21, 2, 11, 23, 9, 11, 20, 9]
[21, 10, 3, 21, 3, 15, 21, 3, 5, 21, 6, 10, 21, 6, 5, 21, 15, 5, 23, 10, 5, 20, 10, 3]
[21, 16, 11, 21, 11, 3, 21, 6, 16, 21, 6, 15, 21, 2, 6, 21, 2, 11, 23, 16, 3, 20, 16, 11]
[22, 22, 22, 21, 17, 19, 21, 17, 5, 21, 17, 12, 21, 5, 12, 21, 19, 5, 23, 17, 12, 20, 17, 19]
[22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 21, 10, 17, 21, 10, 4, 21, 4, 7, 23, 10, 7, 20, 10]
new graph generation procedure that more efficiently generates graphs with higher lookahead: first generate a "fork" graph, where from the start vertex, there are two paths of length k, and the goal is at the end of one of the two paths. then for the remaining vertices, select parent vertices randomly from the existing vertices in the graph. k is the lookahead, so if lookahead == 1, the goal is connected to the start vertex (1-hop). if the lookahead == 0, there is only one valid edge from the start vertex.
generating 100k examples, the histogram of lookaheads looks like:
0:0.20, 1:0.26, 2:0.27, 3:0.27
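a rough sketch of this generation procedure; the function name, the handling of lookahead == 0, and the sampling details are my assumptions, not necessarily the repo's implementation:
import random

def generate_fork_graph(num_vertices, lookahead):
    # two length-`lookahead` paths leave the start vertex; the goal sits at the end of one of them
    vertices = list(range(num_vertices))
    random.shuffle(vertices)
    start = vertices[0]
    k = lookahead
    if k == 0:
        # only one valid edge from the start vertex
        goal = vertices[1]
        edges = [(start, goal)]
        in_graph = [start, goal]
        rest = vertices[2:]
    else:
        branch_a = vertices[1:1 + k]           # branch containing the goal
        branch_b = vertices[1 + k:1 + 2 * k]   # distractor branch of the same length
        edges = []
        for branch in (branch_a, branch_b):
            prev = start
            for v in branch:
                edges.append((prev, v))
                prev = v
        goal = branch_a[-1]
        in_graph = [start] + branch_a + branch_b
        rest = vertices[1 + 2 * k:]
    # remaining vertices: pick parents randomly from the vertices already in the graph
    for v in rest:
        edges.append((random.choice(in_graph), v))
        in_graph.append(v)
    return edges, start, goal

edges, start, goal = generate_fork_graph(num_vertices=16, lookahead=3)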
[22, 21, 7, 1, 21, 6, 8, 21, 6, 7, 21, 1, 12, 21, 8, 14, 21, 14, 18, 23, 6, 18, 20, 6]
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
[22, 21, 5, 19, 21, 11, 5, 21, 10, 3, 21, 4, 10, 21, 9, 4, 21, 9, 11, 23, 9, 3, 20, 9]
9 -> 4 -> 10 -> 3
9 -> 11 -> 5 -> 19
1st attention layer is sparse but the significant elements don't follow any discernable pattern.
2nd attention layer has a strong positive diagonal in the top left 24x24 portion of the A matrix. this is a matching operation
- the attention matrix is copying row 9 into row 21 (corresponding to the edge 10->3)
- after the attention layer, the intermediate result is in position 12 and is sparse, with high activation at both 4 and 10 (though slightly more at 10)
3rd attention layer has a strong positive diagonal immediately below the main diagonal in the bottom-right portion of the A matrix. this is a decrement position operation
- for the input, at the beginning of this layer, the intermediate result is stored in position 12 and is sparse, with high activation at both 4 and 10 (though slightly more at 10)
4th attention layer has weak positive values in the diagonals immediately above and below the main diagonal in the bottom-right 24x24 portion of the A matrix. some of these values seem to be redundant with the 3rd attention layer, where it decrements position by one (at indices 51,59,60,62,63,66). the values in the other off-diagonal corresponds to indices where position is incremented by one (at indices 61 and 64). one index (68) is a decrement operation but is not redundant in the 3rd attention layer.
- for the input, at the beginning of this layer, the intermediate result is stored in position 21 with high activation at 3 (the goal) and slightly less high at 10. row 12 has high activation at 10 and slightly less high at 4
- the attention matrix copies mostly from 12 but also a bit from 11,17,21, and even less from 8,9.
- after the attention layer, the intermediate result is in position 21 and is sparse, with high activation at 10 (and slightly less high at 3)
5th attention layer has weak positive values in the diagonals immediately above and below the main diagonal in the bottom-right 24x24 portion of the matrix, indicating that it is incrementing or decrementing positions (depending on the index)
- for the input, at the beginning of this layer, row 12 contains
- the attention matrix copies row 15 into row 15, and row 12 into row 21 (corresponding to the edge 4->10)
- why is row 12 selected to copy into 21? the dot product of (input'[21,:]*A) and input'[12,:] is particularly large at indices 56,59,60. row 12 of the input has high activation at 4,10,41,55,56,59,60. why is (input'[21,:]*A) large at 56,59,60? (input'[21,:]*A) is large at 56 because A[56,56] and A[57,56] are large. (input'[21,:]*A) is large at 59 because A[60,59] is large. (input'[21,:]*A) is large at 60 because A[60,60] is large. so this layer is doing a backwards step.
- after the attention layer, row 21 has high activation at 10; row 15 has high activation at 4 (and slightly less high activation at 9)
6th attention layer has a weak negative diagonal in the top left 24x24 portion of the A matrix. this is a matching operation where the matched tokens are represented by having small magnitudes (i.e. negative/inverse encoding). it turns out the bottom-right 24x24 portion of the A matrix contains very weak positive diagonal immediately below the main diagonal, indicating that this matrix is doing a decrement position operation but only for some indices.
- row 21 has high activation at 10; row 15 has high activation at 4 (and slightly less high activation at 9)
- the attention matrix copies from row 15 to row 21 (corresponding to the 9->4 edge)
- after the attention layer, row 21 has high activation at 4
- in the FF layer, for the parts of the input that are positive after the first linear layer, we can describe the behavior of the FF layer by simply multiplying the two linear layers. doing so for this layer, we find that the FF layer has a fairly strong positive diagonal, and so it seems like an identity operation. some diagonal elements have not as large magnitude, and most off-diagonal elements have non-negligible magnitude too.
7th attention layer is sparse but the significant elements don't follow any discernable pattern.
- in the FF layer, the product of the two linear weight matrices is also similar to the identity, but there are some significant off-diagonal elements (so perhaps it's better described as a permutation that is similar but not the same as the identity). notably, it seems that for the token 4, the largest contribution comes from row 33 of the input.
- for this input, the attention matrix copies from row 21 to row 21, 22, and 23. before the attention layer, row 21 has high activation at 4
- after the attention layer, rows 21 and 23 have high activation at 4
8th attention layer is sparse but the significant elements don't follow any discernable pattern.
- at the beginning of the attention layer, row 23 has high activation at 4
- the computed attention matrix just copies from 21 into all other rows, whose argmax is 4 (the correct answer)
- interestingly, the elements of the Q*K^T matrix are so negative that they essentially ignore the decoder mask
can FF layers implement matching? it seems the model is using row 21 to store intermediate results before layer 8; where does it store information about intermediate vertices in other layers?
results from trace_circuit.py:
attention layer 0: identity
attention layer 1: token matching
attention layer 2: step backwards
attention layer 3: step backwards for some positions, identity for others
attention layer 4: step backwards
attention layer 5: step backwards
attention layer 6: copy from row 21 to 23
attention layer 7: copy from row 21 to 23
this seems to rely on the heuristic that the goal vertex is in position 21. what about an input where this is not the case?
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
[21, 14, 2, 21, 14, 12, 21, 2, 12, 21, 2, 10, 21, 12, 6, 21, 12, 10, 23, 14, 10, 20, 14, 2]
attention layer 0: identity
attention layer 1: copy from row 19 to 21 (seems to be specific to token 20)
attention layer 2: step backwards
attention layer 3: identity
attention layer 4: copy from row 20 to 21 (seems to be specific to position 21)
attention layer 5: copy from row 20 to 21 (seems to be specific to position 21)
attention layer 6: copy from row 21 to 23
attention layer 7: copy from row 21 to 23
the model here is finding the goal vertex and copying it. what if the lookahead is > 1?
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
[22, 22, 22, 21, 14, 3, 21, 14, 6, 21, 18, 14, 21, 3, 1, 21, 6, 16, 23, 18, 1, 20, 18, 14]
attention layer 0: identity
attention layer 1: token matching
attention layer 2: step backwards (this copies the representation of 1 from row 20 to 21)
attention layer 3: step backwards for some positions, identity for others
attention layer 4: step backwards
attention layer 5: step backwards for some positions, identity for others
attention layer 6: copy from 21 to 23
attention layer 7: copy from 21 to 23
interesting, so here it uses the first backwards step to copy the goal vertex representation from position 20 to 21, and then it continues to keep the intermediate result in position 21 and continue with the backwards search.
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]
[46, 45, 3, 19, 45, 18, 39, 45, 36, 15, 45, 24, 42, 45, 37, 3, 45, 37, 36, 45, 23, 32, 45, 8, 24, 45, 19, 30, 45, 15, 23, 45, 39, 40, 45, 40, 34, 45, 30, 18, 45, 32, 8, 47, 37, 34, 44, 37]
37 -> 3 -> 19 -> 30 -> 18 -> 39 -> 40 -> 34
37 -> 36 -> 15 -> 23 -> 32 -> 8 -> 24 -> 42
last layer (attention layer 4) copies row 32 to 47, but row 47 has strong activation of token 3 before and after this layer. row 32 has strong activation at 3,36.
attention layer 3 copies 2,18,26,35,38,46 into row 47.
before attention layer 3, row 47 has high activation at (28),(58),(85),(91),95
row 2 has high activation at 3,63,72,84,91 -> positions of 3,24,34,47
row 18 has high activation at 36,37,45,71,72,91 -> positions of 8,24,47
row 26 has high activation at 3,45,63,72,84,91 -> positions of 3,24,34,47
row 35 has high activation at 18,39,40,45,48,72,76,80,83 -> positions of 46,24,45,39,40 (seems to be encoding vertices reachable to 40)
row 38 has high activation at 3,58,71,74,91 -> positions of 45,8,19,47
row 46 has high activation at 3,45,62,63,71,72 -> positions of 37,3,8,24
attention layer 1:
copies row 26 into 38 with weight 0.995 -> step backward
copies rows 38,43 into row 5 with weights 0.83,0.14 respectively -> step backward (mostly)
copies row 2 into 26 with weight 0.873 -> step backward
copies row 14 into 2 with weight 0.467 -> step backward (mostly)
copies row 32 into 5 with weight 0.928 -> step backward
copies row 35 into 45 with weight 0.800 -> step backward
attention layer 2:
copies row 32 into 45 -> step backward
copies row 38 into 32 -> step backward
at the beginning of attention layer 3: row 45 contains vertices reachable to 34 in 3 steps, and row 38 contains vertices reachable to 30 in 3 steps (including 37 and 3), and row 32 contains vertices reachable to 39. this is possible because it's using the set merge algorithm for the first 3 layers (so the number of reachable vertices goes from 1, 2, to 4). in attention layer 3, row 38 is copied into 0,47. and row 45 is copied into 0,11,12,41,42. why is row 38 copied into 47 here? because:
Attention layer 3 is copying row 38 into row 47 with weight 0.10488934069871902 because:
Row 38 at index 58 has value 1.3306366205215454
Row 47 at index 46 has value -3.5128743648529053, and A[46,58]=-0.4384665787220001
Row 47 at index 88 has value -0.955220639705658, and A[88,58]=-0.7076801061630249
Row 38 at index 91 has value 1.253493309020996
Row 47 at index 46 has value -3.5128743648529053, and A[46,91]=-0.44504329562187195
Row 47 at index 85 has value 0.5910719633102417, and A[85,91]=1.015777587890625
Row 47 at index 88 has value -0.955220639705658, and A[88,91]=-0.8286808729171753
Row 47 at index 96 has value 1.0, and A[96,91]=0.7197601199150085
-> this looks like a heuristic.
example of lookahead = 6:
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]
[46, 46, 46, 46, 46, 46, 46, 45, 31, 39, 45, 42, 4, 45, 21, 7, 45, 19, 20, 45, 13, 22, 45, 7, 42, 45, 20, 21, 45, 17, 19, 45, 17, 31, 45, 10, 14, 45, 39, 10, 45, 14, 13, 47, 17, 4, 44, 17]
17 -> 31 -> 39 -> 10 -> 14 -> 13 -> 22
17 -> 19 -> 20 -> 21 -> 7 -> 42 -> 4
model correctly predicts 19.
attention layer 4 copies row 45 to 47. (row 45 is the goal vertex)
attention layer 3 copies rows 38,39 into 45. (these rows correspond to the edge 39 -> 10) row 38 has high activation for 31 and 39 which are reachable from 39 with 1 step, but also has high activation at 20 (which is on the other branch). row 39 has high activation for 10,14,19,20,31 (31 is 2 steps behind 10, 14 is 1 step ahead of 10, but 19 and 20 are on the other branch)
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
[62, 62, 62, 62, 62, 61, 15, 8, 61, 11, 18, 61, 9, 5, 61, 19, 14, 61, 19, 17, 61, 1, 11, 61, 6, 7, 61, 10, 3, 61, 2, 1, 61, 13, 10, 61, 12, 4, 61, 17, 16, 61, 7, 12, 61, 14, 2, 61, 3, 9, 61, 16, 15, 61, 18, 6, 61, 8, 13, 63, 19, 4, 60, 19]
19 -> 14 -> 2 -> 1 -> 11 -> 18 -> 6 -> 7 -> 12 -> 4
19 -> 17 -> 16 -> 15 -> 8 -> 13 -> 10 -> 3 -> 9 -> 5
layer 4 copies rows 24,25,42,54,55 into 63 (the output row)
these correspond to tokens 6,7,7,18,6 respectively, which are all on the correct path (and importantly, in the latter portion of the path)
but the input to layer 4 already contains 14 (the correct answer) in row 63
layer 3 copies rows 16,19,31,39 (and some other rows with smaller weight) into 63
only 16 and 31 correspond to vertices along the correct path. 19 and 39 correspond to the vertex 17, which is on the wrong path
row 16 has strong activation at tokens 14,19
row 31 has strong activation at tokens 1,2,14
the reason why rows 16,31 are copied into 63 is because they have strong activation at 191
the input to this layer doesn't have strong activation at 14 in the last row
layer 2 copies row 62 into 16, and row 45 into 31
interestingly, if we change only the goal vertex in the input from 4 to 5, the input to attention layer 4 at row 63 still has strong activation at 14 (which is incorrect, since the correct output now is 17).
layer 4 copies rows 27,28,33,48 into 63
the reason for this is because 189 is small at all these rows
these correspond to tokens 10,3,13,3 respectively, which are all on the correct path (and importantly, in the latter portion of the path)
layer 3 copies 61 into 5,8,11,14,..., and 6,51,52 into 27, and 6,45,63 into 28, and 40,51 into 33, and 6 into 48
the input to layer 3 has many rows where 189 is small (e.g. 5,8,11,14,... mostly corresponding to placeholder tokens)
the reason 61 was copied was that A[187,62] is large
the reason why 6,40,45,51,52,63 are copied is because they have large activation at 191
following the pathway where 17 is large, they seem to be localized to the beginning of the fork 19 -> 17 (only activating vertices 19,17,16,15). so then how does the network know that this path eventually leads to the goal 5?
-> well in attention layer 4, rows 27,28,33,48 have large activation at 17. they are copied because they have small activation at 189. the rows correspond to vertices 10,3,13 which are on the middle-to-latter portion of the correct path to the goal. why is 17 copied into these rows? why are these the only rows with low activation at 189 at the input to attention layer 4? for example, why do the rows corresponding to vertices 6,7,18 (on the other side of the fork from 10,3,13) not have low activation at 189? these rows are 24,55,25,42,10,54.
-> FF layer 3 actually causes rows 24,55,25,42,10,54 to have high activation at 189 (specifically its second linear layer). this FF layer seems to copy activation from element 79 into 189, and copy low activation at 62 into high activation at 189. before FF layer 3, rows 27,28,33,48 are not the only ones with low activation at 189. it seems that it is changing other rows to have higher activation at 189.
- doing perturbation analysis on FF layer 3, increasing element 79 to 2.5 causes the largest increase in 189, which suggests that if the input at 79 is low, then 189 will be low. but this is not sufficient, as many other rows also have low activation at 79. decreasing element 60 to -2.5 causes the largest increase in 189, which suggests that if the input at 60 is high, then 189 will be low. but again this is not sufficient, as many other rows have high activation at 60. but these two effects together almost perfectly produces low activation at 189. specifically, the input rows where activation 79 is less than -1.0 and activation 60 is greater than 1.3 are: 27,28,33,48,63. row 63 element 79 is -4.8, and element 60 is 1.7, but then why does it have low activation at 189 after FF layer 3? well perturbing the input row 63 at element 94 to have value -3.5 causes a large decrease in the activation at 189. and so a large value at 94 will cause a larger activation at 189. in fact, row 63 has the largest activation at 94.
- so it seems FF layer 3 is performing a set intersection operation: the input rows with activation at 79 < -1.0, activation at 60 > 1.3, and activation at 94 < 0.7 will become the output rows with low activation at 189. the question is, what are these input sets and where do they come from? see further below for deeper exploration into this.
-> attention layer 3 copies 17 from rows 6,51,52,40,19,39 into 27,28,33,48 due to the high activation at 191 at rows 6,51,52,40,19,39 (these rows correspond to the vertices 15,16,17 which are in the beginning part of the correct path), and low activation at 189 in rows 27,28,33,48.
- attention layer 3 is also copying row 61 into 24,55,25,42,10,54 (and every other row except 12,13,27,28,33,48,49,63). why? it seems to be due to multiple conditions. for rows 24,55,25 (and others) the activation at 145 in these rows is large, and the activation at 45 in row 61 is large.
-> is FF layer 2 altering the representation at 191? the activation at 191 in row 51 (for example) is increased quite a lot at FF layer 2. but even before, there is moderate activation at 191.
how do we "undo" the transformation at the FF layer?
what is the derivative of [A2^T]_i * f(A1^T * x + b1) + b2_i w.r.t. x?
d/dx([A2^T]_i * f(A1^T * x + b1) + b2_i)
= [A2^T]_i * d/dx(f(u)) where u = A1^T * x + b1
= [A2^T]_i * df(u)/du * du/dx
= [A2^T]_i * df(u)/du * A1^T
f : R^n -> R^n, where f(u)_i = u_i if u_i > 0, else 0. df(u)/du is a matrix where [df(u)/du]_ij = df(u)_i/du_j = 1 if i = j and u_i > 0, else [df(u)/du]_ij = 0
i think something like this:
torch.matmul(torch.matmul(ff_parameters[2][1][:,51], torch.diag(1.0 * (self.model.transformers[2].ff[0](self.model.transformers[2].ln_ff(ff_inputs[2][51,:])) > 0.0))), ff_parameters[2][0].T)
which can be approximated by the difference formula:
diff = (self.model.transformers[2].ff(self.model.transformers[2].ln_ff(ff_inputs[2][51,:] + 1.0e-4*torch.eye(192))) - self.model.transformers[2].ff(self.model.transformers[2].ln_ff(ff_inputs[2][51,:])))[:,191] / 1.0e-4
another approach which might be better is a perturbation test; something like:
self.model.transformers[2].ff(self.model.transformers[2].ln_ff(ff_inputs[2][51,:].repeat(193,1).fill_diagonal_(0.0)))[:,191]
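a self-contained version of these three probes (analytic Jacobian row, finite difference, and leave-one-out perturbation) on a generic 2-layer ReLU MLP with random weights; the sizes and parameter layout here are stand-ins for the real ff module, not the checkpoint's:
import torch

d = 192
torch.manual_seed(0)
W1, b1 = torch.randn(d, d) / d**0.5, torch.zeros(d)   # first linear layer (A1^T, b1 in the notation above)
W2, b2 = torch.randn(d, d) / d**0.5, torch.zeros(d)   # second linear layer (A2^T, b2)

def ff(x):
    return torch.relu(x @ W1.T + b1) @ W2.T + b2

x = torch.randn(d)
out_idx = 191                                          # output coordinate of interest

# analytic gradient of output[out_idx] w.r.t. x: [A2^T]_i * diag(1[u > 0]) * A1^T
u = x @ W1.T + b1
grad = W2[out_idx, :] @ torch.diag((u > 0).float()) @ W1

# finite-difference approximation of the same gradient
eps = 1e-4
diff = (ff(x + eps * torch.eye(d)) - ff(x))[:, out_idx] / eps

# leave-one-out perturbation: zero out each input coordinate and see how output[out_idx] moves
perturbed = ff(x.repeat(d, 1).fill_diagonal_(0.0))[:, out_idx] - ff(x)[out_idx]

print(torch.allclose(grad, diff, atol=1e-2))           # the two gradient estimates agree
print(perturbed.abs().topk(5).indices)                 # input coordinates whose removal moves output[out_idx] most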
-> the high activation at 191 of rows 6,51,52,40,19,39 at the beginning of attention layer 3 is due to:
- row 51 has high activation at 60 and low activation at 188,189
- rows 6,52,40,19,39 has high activation at 191
-> attention layer 3 copies row 39 into 6 (the representation of vertex 17 is copied into that of vertex 15 due to the backwards set union step)
- copies row 39 into 52 (again the representation of vertex 17 is copied into that of vertex 15)
- copies row 19 into 40 (the representation of vertex 17 is copied into that of vertex 16)
- copies row 62 into 19 (the representation of the path prefix token into that of vertex 17) because 19 has low activation at 79 and 62 has high activation at 60 (the path prefix token)
- copies row 62 into 39 (the representation of the path prefix token into that of vertex 17) because 39 has low activation at 79 and 62 has high activation at 60 (the path prefix token)
-> attention layer 2 copies row 57 into 27 (the representation of vertex 8 into that of vertex 10 due to the backwards set union step)
- copies row 33 into 28 (the representation of vertex 13 into that of vertex 3)
- copies row 6 into 33 (the representation of vertex 15 into that of vertex 13)
- copies row 33 into 48 (the representation of vertex 13 into that of vertex 3)
- at this point, rows 27,28,33,48,57,33 have low activation at 189 (along with many other rows)
-> attention layer 1 copies row 51 into 6 (the representation of vertex 16 into vertex 15 due to the backwards step)
- copies row 33 into 27 (the representation of vertex 13 into that of vertex 10)
- copies row 6 into 57 (the representation of vertex 15 into that of vertex 8)
- copies row 57 into 33 (the representation of vertex 8 into that of vertex 13)
- copies row 27 into 28 (the representation of vertex 10 into that of vertex 3)
- copies row 27 into 48 (the representation of vertex 10 into that of vertex 3)
- at this point, rows 27,28,33,48,57,33,6,51 have low activation at 189 (along with many other rows)
-> FF layer 0 seems to provide all indices with low activation at 189 via the bias term of the second linear layer
-> going back to how FF layer 3 is performing a set intersection operation: the input rows with activation at 79 < -1.0, activation at 60 > 1.3, and activation at 94 < 0.7 will become the output rows with low activation at 189. the question is, what are these input sets and where do they come from?
- we can follow rows 27,28,33,48. attention layer 3 performs the following copies: 6,51,52 -> 27, 6,45,63 -> 28, 40,51 -> 33, 6 -> 48. these copies are performed because rows 27,28,33,48 have low activation at 189, whereas the source rows have high activation at 191. the tokens with high activation at 191 are the goal vertex, the source vertex, and all vertices reachable from the source vertex with 3 hops (along both paths). the tokens with activation at 79 < -1.0 is the same set, except without the goal vertex. the tokens with activation at 60 > 1.3 are 19,10,3,13, plus a bunch of placeholder tokens. it seems the set of vertices with low activation at 189 is fairly broad, spanning both paths:
- the set of vertices with activation at 189 < -1.86: 19,2,16,1,11,8,18,13,6,10,7,3,9,4,5
- the set of vertices with activation at 189 < -2.25: 3,4,9,11,18,19 (it's possible the subsequent computation is dominated by the signal at row 28, corresponding to vertex 3)
- but interestingly, other vertices in this set seem to copy exclusively from row 61 in this attention layer. the last three vertices in the other path are 6,7,12. 6 appears in rows 24,55; 7 appears in 25,42; 12 appears in 36,43. all of these rows copy exclusively from row 61. why? for rows 24,55,25,36,43, this is due to high activation at 145 at these rows and high activation at 45 at row 61. the rows with activation at 145 > 1.2 are the start vertex and the last 4 vertices on the wrong path, except the very last vertex.
- FF layer 2 doesn't seem to change the representation of element 145 across rows. looking at rows 28,33,48, attention layer 2 copies row 33 into 28; 6 into 33; and 33 into 48. rows 33,6 correspond to vertices 13,15.
-> how different does this look when we switch the goal vertex to the other path? it seems the computation from this point to the end is the same. the vertices with large activation at 60 in the input to FF layer 3 are in rows 24,25,42,54,55 which correspond to vertices near the end of the correct path (as well as the placeholder tokens). attention layer 3 copies row 61 into rows 27,28,33,48 (corresponding to the vertices near the end of the other, now incorrect path).
- in attention layer 3, row 24 copies from rows 21,30,31,63
- row 25 copies from 21
- row 42 copies from 21,54
- row 54 copies from 45,46,61
- row 55 copies from 30
- these rows correspond to vertices near the beginning of the current path (plus the goal vertex at 61)
- for rows 25 and 54, the source is determined by a decrement operation. the input rows with large activation at 60 are 21,30,45, which correspond to the first three vertices in the correct path, not including the start vertex.
- the rows 27,28,33,48 (corresponding to the latter vertices on the other path) are copied from row 61. the reason is that those rows have very low activation at 189. but note that rows 21,30,31,54,45,46 have similarly low activation at 189.
- FF layer 2 seems not to change the representation much for element 60 at rows 21,30,45.
- attention layer 2 performs the following copies: row 21 copies from 45; row 30 copies from 16; and row 45 copies from 62. rows 45,16,62 correspond to vertex 14 (the first vertex along the correct path not including the start vertex) and the path prefix token. the rows with high activation at 60 correspond to most placeholder tokens, and vertices 14,17,19, which are the first vertices along each path plus the start vertex.
- FF layer 1 also seems not to change the representation much for element 60 at rows 45,16,62.
- attention layer 1 copies row 15 into rows 45,16; and copies many rows into 62 (all corresponding to vertices randomly spread across both paths). the input rows to attention layer 1 with high activation at 16 include some special tokens (including both the query prefix token and the path prefix token) as well as rows corresponding to vertex 19.
- FF layer 0 seems to be the source of high activation at 60 for many rows, as only row 60 has high activation at 60 before this layer. from perturbation analysis, setting elements 60,61 to high activation (1.8) causes element 60 to have much lower activation in the output. so very low activation at both 60,61 causes high activation at 60 in the output of this FF layer.
- attention layer 0 copies rows 60,62,63 into row 15, which correspond to the start vertex, the query prefix token, and path prefix token.
- interestingly, in attention layer 3, the rows with small activation at 145 are the same regardless of the goal vertex. it always seems to encode one path but not the other. but 145 isn't used when the goal vertex is 4. rather it seems to rely on the activation at 188 being low (in addition to the activation at 189 being low). in fact, the set of vertices with low activation at 188 is almost exactly the correct path. however, this set does not change when the goal is switched (so relative to the new goal it encodes the incorrect path). is the difference in row 61's representation at element 45? YES
- FF layer 2 seems to alter the representation at row 61, and it's not clear exactly what it's doing. it seems to depend non-trivially on multiple elements of the input.
- attention layer 2 copies row 42 into 61. after attention layer 2, row 61 stores all vertices reachable from the goal vertex within 3 steps.
so the reconstructed algorithm is as follows:
attention layer 0 does matching (positive for non-placeholder tokens and negative for placeholder tokens)
attention layer 1 and 2 do backwards steps
attention layer 3 copies vertices reachable from the starting vertex (i.e. with large activation at 191) into rows 61 and 63
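for reference, the inputs analyzed in these notes encode each edge as an (EDGE_PREFIX_TOKEN, source, target) triple, and a 'backwards step' maps a vertex to the source(s) of the edges pointing into it. a minimal plain-Python sketch of those two notions (hypothetical helpers, not the model's computation):
def parse_edges(tokens, edge_prefix):
    # collect (source, target) pairs from the (edge_prefix, src, dst) triples in the token list
    return [(tokens[i+1], tokens[i+2]) for i, t in enumerate(tokens[:-2]) if t == edge_prefix]
def backwards_step(edges, v):
    # one backwards step: the sources of all edges whose target is v
    return [s for (s, d) in edges if d == v]
# e.g. for the input below (edge prefix 43), backwards_step(parse_edges(tokens, 43), 4) gives [22]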
[44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 43, 15, 34, 43, 30, 9, 43, 14, 22, 43, 8, 13, 43, 8, 2, 43, 26, 1, 43, 1, 14, 43, 36, 7, 43, 22, 4, 43, 22, 2, 43, 34, 26, 43, 34, 25, 43, 28, 30, 43, 16, 3, 43, 16, 32, 43, 13, 33, 43, 12, 15, 43, 25, 21, 43, 9, 36, 43, 3, 12, 43, 32, 8, 43, 33, 28, 45, 16, 4, 42, 16]
goal vertex is 4
after the first layer, the goal vertex position should contain activation for: token value 4, position 187 (goal vertex position), and position 145 (position of other '4')
all positions containing '22' should contain activation for: token value 22, position 127 (position of first '22'), position 144 (position of second '22'), and 147 (position of other '22')
16 -> 3 -> 12 -> 15 -> 34 -> 26 -> 1 -> 14 -> 22 -> 4
layer 1 does token matching
layer 2 does the 1st backwards step: x[65] is copied into x[-3]
layer 3 does the 2nd backwards step: x[64] is copied into x[-3]
layer 4 does the 3rd backwards step: x[73] is copied into x[-3]
the goal isn't copied into x[-3] for some reason (it gets to 8 reachable vertices but the next backwards step seems incorrect)
the attention layer copies from positions 73 and 76 (containing tokens 26 and 1 respectively). instead, it should be copying from position 115 (corresponding to the source vertex of the edge 3 -> 12).
we fixed this by adding a diagonal term to the k_proj matrix at -2,-2 (matching the position of the PATH_PREFIX_TOKEN)
we need to fix the issue of the self-activation of the last token being too high (it may predict 16 instead of the correct answer which is 3). maybe try reducing the activations of all tokens in the first layer?
we fixed this by tweaking the FF-layer. by changing k in the function f(x) = x - k*ReLU(x - 1) we can maintain information about the relative sizes of the token activations within each embedding.
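a minimal sketch of that FF tweak (the k values here are hypothetical; the point is that for k < 1 the map stays strictly increasing, so the relative ordering of activations within each embedding is preserved while large values are compressed):
import torch
def soft_clip(x, k):
    # f(x) = x - k * ReLU(x - 1): identity below 1, slope (1 - k) above 1
    return x - k * torch.relu(x - 1.0)
x = torch.tensor([0.5, 1.2, 1.8, 3.0])
print(soft_clip(x, k=1.0))  # hard clip at 1: the ordering of the large activations is lost
print(soft_clip(x, k=0.8))  # k < 1: large activations are damped but their ordering survives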
analyzing the unablated max_input_length = 90 model with 6 layers, max train lookahead = 10, seed = 2:
looking at the input:
[31, 31, 31, 31, 31, 31, 31, 30, 7, 23, 30, 9, 22, 30, 6, 4, 30, 6, 10, 30, 25, 19, 30, 17, 9, 30, 17, 16, 30, 1, 14, 30, 11, 21, 30, 26, 1, 30, 12, 11, 30, 14, 6, 30, 15, 25, 30, 4, 17, 30, 24, 28, 30, 19, 8, 30, 27, 26, 30, 27, 12, 30, 27, 5, 30, 22, 24, 30, 8, 3, 30, 18, 15, 30, 3, 7, 30, 3, 10, 30, 3, 2, 30, 21, 18, 32, 27, 28, 29, 27]
27 -> 26 -> 1 -> 14 -> 6 -> 4 -> 17 -> 9 -> 22 -> 24 -> 28
27 -> 12 -> 11 -> 21 -> 18 -> 15 -> 25 -> 19 -> 8 -> 3 -> 7 -> 23
27 -> 5
6 -> 10
17 -> 16
3 -> 10
3 -> 2
attention layer 5 copies row 35 into row 89 with weight 0.7547, and copies row 86 into row 89 with weight 0.0948 (it also copies row 38 into 89 with weight 0.1346 but row 38 has very low activation at 26). rows 35 and 86 of the input to attention layer 5 have large activation at 26, whereas rows 38 and 89 have low activation at 26.
before FF layer 4, rows 35 and 86 still have high activation at 26 (along with high activation at non-token elements). in fact, row 86 also contains high activation at 5, 12, and 26 (excluding non-token activations), which are the valid next steps from the current vertex.
before attention layer 4, row 86 has higher activation at 5 than 26. in fact, row 86 has highest activation at 5, 12, and 26 (excluding non-token activations), which are the valid next steps from the current vertex. row 35 has high activation at 26.
before FF layer 3, row 86 still has high activations at 5, 12, 26. row 35 has high activation at 26.
before attention layer 3, row 86 still has high activations at 5, 12, 26. row 35 has high activation at 26.
before FF layer 2, row 86 still has high activations at 5, 12, 26. row 35 has high activation at 26.
before attention layer 2, row 35 still has high activation at 26. however, row 86 no longer has high activation at 5, 12, 26. the linear layer after attention layer 2 significantly increases the activation in row 86 at element 26. what causes such a large increase in activation? the bias term at 26 is -0.0300 so this is not the cause. the 26th row of the weight matrix seems to select large activations at 135 and small activations at 88.
attention layer 2 copies rows 56, 59, and 62 into row 86, with weights 0.2018, 0.1861, 0.1862, respectively. rows 59 and 62 have strong negative activation at 88.
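a small helper for reading such copies off directly (a sketch; attn is assumed to be the layer's post-softmax attention matrix of shape (n, n), with destination rows along dim 0):
import torch
def top_sources(attn, dst_row, k=5):
    # the k source rows with the largest attention weight for a given destination row
    w, idx = torch.topk(attn[dst_row], k)
    return list(zip(idx.tolist(), w.tolist()))
# e.g. top_sources(attn5, 89) should recover rows 35, 38, 86 with weights close to the ones quoted above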
how is the high activation at row 35 computed before the input to attention layer 2? it's because the input token at position 35 is 26. so the more important question is: why is row 35 selected by attention layer 5? because row 35 has strong negative activation at 138 and row 89 has strong positive activation at 138. (row 35, element 138 is -1.5198, whereas row 38, element 138 is -1.6055) -> element 138 is not sufficient since both rows 35 and 38 have large negative values.
at element 40: row 35 has -0.5226, row 38 has 0.0214
at element 113: row 35 has 0.3868, row 38 has -0.6394
at element 49: row 35 has -0.0824, row 38 has 0.3710
-> from a perturbation analysis, it's neither the representation at row 35 nor at row 38 that causes the difference in behavior. rather, it's due to row 89. but which part? there seem to be a handful of important indices: 40, 50, 135, 137, and (maybe not as important but still important) 55, 36, 110, 64, 120, 30, 54, 57
more specifically, row 89, elements 40, 55, and 57 have large negative value.
before FF layer 4, row 89 still has large negative activations at 40, 55, 57
but attention layer 4 makes the activations at 40 and 55 more negative. the attention layer copies row 23 into row 89.
attention layer 3 copies rows 11, 23, 29, 32, 53 into row 23. it also copies rows 29, 35, 38, 53, 65, 71 into row 89.
attention layer 2 copies rows 56, 59 into 89. (token matching?) it also copies many rows into 23. it copies rows 14, 23, 47 into 11. it copies rows 32, 50 into row 29. it copies rows 32, 38, 47 into row 32.
attention layer 1 copies 87 into EDGE_PREFIX_TOKEN and also into many target vertices and some source vertices. it also copies row 83 into rows 47, 83. (row 47 contains the token 24 which is the vertex immediately preceding the goal)
attention layer 0 copies row 87 into row 83 (and a small weight is copied into all rows with an EDGE_PREFIX_TOKEN).
if we change the goal vertex to 7, the model correctly predicts 12. how?
attention layer 5 copies rows 38 and 86 into row 89, with weights 0.6330 and 0.3143 respectively. the inputs to attention layer 5 at row 86 also have high activation at 5, 12, 26. the input at row 38 has high activation at 12.
...
before FF layer 2, row 86 still has high activations at 5, 12, 26. row 38 still has high activation at 12.
how is the high activation at row 38 computed before the input to attention layer 2? it's because the input token at position 38 is 12. so the more important question is: why is row 38 selected by attention layer 5? because row 38 has strong negative activation at 138 and row 89 has strong positive activation at 138. (row 35, element 138 is -1.4799, whereas row 38, element 138 is -1.5498) -> element 138 is not sufficient since both rows 35 and 38 have large negative values.
at element 40: row 35 has -0.5549, row 38 has -0.0280
at element 113: row 35 has 0.4127, row 38 has -0.6588
at element 49: row 35 has -0.1794, row 38 has 0.3273
attention layer 1 copies row 87 into EDGE_PREFIX_TOKEN and also into many target vertices (3, 6, 9, 12, 18, 21, 27, 30, 33, 36, 39, 51, 54, 60, 72, 75, 78, 81) it also copies row 83 into rows 74, 83. (row 74 contains the token 3 which is the vertex immediately preceding the goal)
attention layer 0 copies row 87 into row 83 (and a small weight is copied into all rows with an EDGE_PREFIX_TOKEN).
need to make the A-matrix analysis more robust. suppose the embedding at position i is x_i. we want to find out why i attends to j at a given layer. so why is x_i * A * x_j^T larger than x_i * A * x_k^T for k != j? consider decomposing each x_i into y_i + z_i such that the y_i component is not useful for making x_i * A * x_j^T larger than x_i * A * x_k^T, whereas the z_i component is useful.
x_i * A * x_j^T = x_i * A * (y_j + z_j)^T
if we simplify this into the sparse setting, suppose z_j is one-hot. then to compute the non-zero index of z_j, we can compute [x_i * A]_l * x_{j,l} for every l, and [x_i * A]_l * x_{k,l} for every l. then we can find the l such that [x_i * A]_l * x_{j,l} - [x_i * A]_l * x_{k,l} = [x_i * A]_l * [x_j - x_k]_l is maximal. thus, l is the index of x_j that makes x_i * A * x_j^T larger than x_i * A * x_k^T, and A[:,l] helps to select this index.
but more generally, we could also look at every l where [x_i * A]_l * [x_j - x_k]_l > 0, which contributes to help make x_i * A * x_j^T > x_i * A * x_k^T. the relative magnitudes of [x_i * A]_l * [x_j - x_k]_l can vary (some l's can be more important than others), and we can measure the relative importance by looking at their relative magnitudes (maybe normalize?).
however, it may be the case that values of l such that [x_i * A]_l * [x_j - x_k]_l is very positive are cancelled out by values of l such that [x_i * A]_l * [x_j - x_k]_l is very negative, so does it make sense to only look at values of l where [x_i * A]_l * [x_j - x_k]_l > 0?
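a sketch of this per-element decomposition (assuming x is the n x d input to the attention layer after layer norm and A is that layer's attention matrix, as in the notes above):
import torch
def attention_contributions(x, A, i, j, k):
    # per-element contribution of index l to the score gap x_i * A * x_j^T - x_i * A * x_k^T:
    # c[l] = [x_i * A]_l * [x_j - x_k]_l, so c.sum() equals the gap exactly
    c = (x[i] @ A) * (x[j] - x[k])
    order = torch.argsort(c, descending=True)  # order[0] is the l that most favors attending to j over k
    return c, order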
looking at the above example, we find that in attention layer 5, row 35 is copied into row 89, and the reason is because [x_i * A]_l * [x_j - x_k]_l is large when i=89, j=35, l=138, for all values of k except 35 and 38 (and also 5 and 86). so one important component of row 35 that facilitates its selection in this attention layer is that the element at 138 is very negative (this is probably encoding some important information; especially since input[35] is 26 and input[38] is 12, which are two of the valid next steps from the current vertex). but why is row 35 copied and not row 38?
this is a lot less clear, there doesn't seem to be one obvious element that's much larger than others. the largest contributor is element 40. to find out why, we have to consider why [x_i * A]_40 * [x_j - x_k]_40 is so large. well x_{35,40} is -0.5226 and x_{38,40} is 0.0214, and [x_89 * A]_40 is -9.8504 (why is this so negative? well because x_{89,36} = 0.8297 and A_{36,40} = -2.4941). so we would want to keep track of the negative activation at 40 in row 35.
the next largest contributor is element 125. x_{35,125} is -0.0072, x_{38,125} is 0.6846, and [x_89 * A]_125 is -5.4847 (TODO: why is this so negative?).
so perhaps the correct way to think about this is that there are certain values of x_35 that need to be as negative or as positive as possible in order to maximize the value of x_89 * A * x_35^T, and we can keep track of these "directions" such as in: element 125 should be as negative as possible, element 40 should be as negative as possible, element 138 should be as negative as possible.
so in this view, each value of x_35 is assigned either + or - indicating whether increasing that element's value will increase or decrease the attention weight on 35. propagating this backward, we encounter the layer normalization, but this does not actually change the signs that we're keeping track of. before this step, we have the FF layer. the FF layer, however, could negate some of these elements, which makes the sign bookkeeping harder to propagate through it.
maybe we should compute the gradient of the cross entropy of the final output and backpropagate it?
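a sketch of that gradient idea (embed, model, and tokens are assumed names here; the point is to get per-element input sensitivities instead of hand-propagating signs through each layer):
import torch
import torch.nn.functional as F
# hypothetical: embed maps token ids to a (1, n, d) tensor, model maps embeddings to (1, n, vocab) logits
embeddings = embed(tokens).detach().requires_grad_(True)
logits = model(embeddings)
target = torch.tensor([26])                      # the correct next vertex in the goal=28 example
loss = F.cross_entropy(logits[0, -1:], target)   # cross entropy at the final position
loss.backward()
print(embeddings.grad[0, 35].abs().topk(5))      # which elements of row 35 the prediction is most sensitive to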
from perturbing the goal from 28 to 7 (the other side of the fork), we note that x_{35,40} and x_{38,40} don't change much, but x_{89,36} changes quite a bit (it's now -0.1181).
going back to when the goal is 28, we see that x_{89,36} is large even before FF layer 4. attention layer 4 seems to cause x_{89,36} to become large, as the input to this layer has small x_{89,36}. attention layer 4 copies row 23 into row 89 with weight 0.9760 (NOTE: position 23 of the input corresponds to the source vertex of the edge 17 -> 9). why does this copy occur? well [x_89 * A]_77 * x_{23,77} is large (22.3977), and no other row has as large of a contribution from element 77 (i.e. [x_89 * A]_77 * x_{j,77} is small for j != 23). why is this value large, well because x_{23,77} = 0.4133, [x_89 * A]_77 = 63.8712. why is this large? because x_{89,138} = 5.6221 and A_{138,77} = 10.8522.
-> also, where does the high activation at x_{89,36} come from? it has to be from the linear layer at the end of attention layer 4. inspecting the product of [P_L * P_V]_36 and the input x_23, we find that the elements at 47 and 95 contribute the most.
perturbation analysis, how do these activations change? x_{89,138} is still large, but x_{23,77} = -0.8337. similarly, x_{23,47} and x_{23,95} don't change much (doing linear(proj_v(x_23)) on the perturbed input leads to large activation at 36 as well). attention layer 4 instead copies rows 20 and 53 into row 89 (corresponding to the source vertex of the edge 25 -> 29 and the source of the edge 19 -> 8, respectively).
going back to goal = 28, attention layer 3 copies multiple rows into row 23: 11, 23, 29, 32, 53 (and many more with smaller, but not that much smaller, weight). these correspond to the source of 9 -> 22, the source of 17 -> 9, the source of 1 -> 14, the source of 11 -> 21 (wrong path), the source of 19 -> 8 (wrong path). where is x_{23,77} copied from? it seems to be copied from rows 11, 23, 29, and 53, with the strongest contributions from 11, 29, and 53. what are the source values of x_11 and x_29 that cause high activation at 77? the contribution from row 11 to x_{23,77} is almost entirely due to high activation at x_{11,1}. similarly with row 29, it's due to high activation at x_{29,1} (though inspecting these values, they are quite modest. the size of the contribution is mainly due to the very large weight, 23.7961, at [P_V^T * P_L^T]_{1,77}).
why are x_{11,1} and x_{29,1} (modestly) positive? the corresponding elements before FF layer 2 are also large. x_{29,1} is large because the value of input[29] is 1. so then why is x_{11,1} large?
attention layer 2 doesn't make much of a change to x_{11,1}. the input to FF layer 2 also has large activation at x_{11,1}. attention layer 1 also doesn't make much of a change to x_{11,1}.
FF layer 0 seems to increase the activation at x_{11,1} by some amount. this seems to be largely due to the bias term of the 2nd linear component in this FF layer.
what if we try to trace the circuit from the input?
attention layer 0:
goal=28: attention layer 0 copies row 87 into 83, which corresponds to the source vertex of the edge 21 -> 18.
it doesn't seem to copy row 89 into anything. but it copies row 86 uniformly into many rows corresponding to source vertices, except for 56, 59, 62, 83, which correspond to the edges 27 -> 26, 27 -> 12, 27 -> 5, 21 -> 18.
goal=7: same thing
FF layer 0 seems to affect both inputs similarly.
attention layer 1:
goal=28: attention layer 1 copies row 87 into almost every row uniformly, except with smaller weights on rows 83 and 86. it also copies row 83 into row 47 with weight 0.3242 and into row 83 with weight 0.2294 and rows 0, 1, 2, 3, 5, 8, 29, 76, 79 with weights between 0.05 and 0.1. row 47 corresponds to the source vertex of the edge 24 -> 28.
it copies row 89 into a bunch of other rows corresponding to source vertices, but they're seemingly random. it copies row 86 into row 88. it copies row 56 into 86 and 89. it copies row 59 into 59, 86, 87, 89. it copies row 62 back into 86.
goal=7: attention layer 1 copies row 87 into almost every row uniformly, except with smaller weights on rows 74, 83, and 86. it also copies row 83 into row 74 with weight 0.3166 and into row 83 with weight 0.1984 and rows 2, 5, 14, 23, 29, 32, 38 with weights between 0.05 and 0.1. row 74 corresponds to the source vertex of the edge 3 -> 7.
it copies row 89 into a bunch of other rows corresponding to source vertices, but they're seemingly random. it copies row 86 into row 88. it copies row 56 into 86 and 89. it copies row 59 into 59, 86, 87, 89. it copies row 62 back into 86.
-> this is a backwards step from the goal
attention layer 2:
goal=28: attention layer 2 copies row 47 into many rows uniformly:
8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 44, 50, 53, 68, 71, 74, 77, 83, 88 (these are all source vertices except 56, 59, 62, especially 65, which instead have lower weights, but not as low as non-source vertices) -> the row corresponding to a source vertex with low activation is 65, which corresponds to the source vertex of the edge 22 -> 24. this looks like a backwards step with negative encoding
it also copies row 83 into almost every row uniformly, except 86 (which is the start vertex). row 87 is not really copied anywhere (the weights are less than 0.01)
goal=7: attention layer 2 copies row 74 into many rows uniformly:
11, 14, 17, 20, 23, 29, 32, 35, 41, 44, 47, 50, 65, 71, 77, 83, 88 (these are all source vertices except 56, 59, 62, especially 68, and 74, which instead have lower weights, but not as low as non-source vertices) -> the rows corresponding to a source vertex with low activation are 68 and 74, which correspond to the source vertices of the edges 8 -> 3 and 3 -> 7
it also copies row 83 into almost every row uniformly, except 86 (which is the start vertex). row 87 is not really copied anywhere (the weights are less than 0.01)
attention layer 3:
goal 28: attention layer 3 copies row 65 into 41 (and many other rows with weight < 0.1) -> 41 corresponds to the source vertex of the edge 14 -> 6
goal 7: attention layer 3 copies row 68 into many rows uniformly, except 8, 74, 80, 87, and especially 20 and 53. 20 and 53 correspond to the edges 25 -> 19 and 19 -> 8, respectively.
another backwards step?
attention layer 4:
goal 28: attention layer 4 copies row 41 into row 35, which corresponds to the source vertex of the edge 26 -> 1 (26 is the correct answer).
goal 7: attention layer 4 copies row 20 into many rows, but rows 83, 88, and 89 have highest weight. 83 corresponds to the edge 21 -> 18, whereas 88 and 89 are the last two tokens of the input.
attention layer 5:
goal 28: row 35 is copied into row 89.
goal 7: row 83 is barely copied anywhere. row 38 is copied into row 89.
some perturbation analysis when changing the goal from 28 to/from 7:
in the inputs to layer 5, it seems like rows 35 and 38 make no difference. row 89 seems to make all the difference (is it storing the set of reachable vertices from the goal? and then layer 5 computes the intersection with the next valid vertices?).
for the inputs to layer 4, however, it seems only the representations at rows 20 and 23 cause the prediction to flip from one path to the other. when the goal is 28, layer 4 copies row 23 into 89. if the goal is 7, layer 4 copies rows 20 and 53 into 89. why are these rows selected to copy from?
when the goal is 28, row 23 is selected because row 23 has large positive activation at 77, whereas row 20 has large negative activation at 77. when the goal is 7, row 20 has positive activation at 77, and row 23 has negative activation at 77. in fact, these are the only rows where activation at 77 is positive. this is due to the fact that the A-matrix at layer 4 has strong activation at [138,77]. what is activation at 77 encoding?
this large activation at 77 is traceable to before FF layer 3. it seems to encode reachability to the goal. but if this is the case, why does the last layer choose from the next valid vertices correctly?
some perturbation analysis when swapping the target vertices of the edges 6 -> 4 and 18 -> 15, so then the edges become 6 -> 15 and 18 -> 4:
in the inputs to layer 5, row 89 makes no difference. in contrast to the previous perturbation, rows 35 and 38 determine which is copied from.
layer 4 copies rows 14 and 41 into row 35, and row 71 into 38.
layer 3 copies rows 14, 29, 32, and 71 into 14 and 41.
when perturbing the target vertices of the edges 4 -> 17 and 15 -> 25, both rows 89 and 35,38 are important for determining the correct copy at the last layer.
attention layer 5 copies row 35 into 89 because element 40 is negative and element 125 is zero.
attention layer 4 copies row 14 into 35 and the negativity of element 40 comes from: element 39 is negative, element 106 is positive, and element 122 is positive.
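the perturbation comparisons above all follow the same recipe; a sketch (run_model is an assumed helper returning the per-layer residual representations as n x d tensors; tokens is the input repeated below, whose goal token 28 sits at position 87):
import torch
tokens_pert = list(tokens)
tokens_pert[87] = 7                          # switch the goal from 28 to 7
reps_base = run_model(tokens)                # hypothetical: list of per-layer (n x d) activations
reps_pert = run_model(tokens_pert)
diff = (reps_base[4] - reps_pert[4]).abs()   # compare e.g. the inputs to layer 4
print(diff[89].topk(10))                     # which elements of row 89 change the most under the perturbation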
repeating the input from above for convenience:
[31, 31, 31, 31, 31, 31, 31, 30, 7, 23, 30, 9, 22, 30, 6, 4, 30, 6, 10, 30, 25, 19, 30, 17, 9, 30, 17, 16, 30, 1, 14, 30, 11, 21, 30, 26, 1, 30, 12, 11, 30, 14, 6, 30, 15, 25, 30, 4, 17, 30, 24, 28, 30, 19, 8, 30, 27, 26, 30, 27, 12, 30, 27, 5, 30, 22, 24, 30, 8, 3, 30, 18, 15, 30, 3, 7, 30, 3, 10, 30, 3, 2, 30, 21, 18, 32, 27, 28, 29, 27]
27 -> 26 -> 1 -> 14 -> 6 -> 4 -> 17 -> 9 -> 22 -> 24 -> 28
27 -> 12 -> 11 -> 21 -> 18 -> 15 -> 25 -> 19 -> 8 -> 3 -> 7 -> 23
27 -> 5
6 -> 10
17 -> 16
3 -> 10
3 -> 2
using the new contribution representation analysis, with the following used to compute the magnitude of the contribution from various inputs:
torch.linalg.vector_norm(representations[3][30,:,:], dim=-1)
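for context, the assumed bookkeeping behind this line: representations[layer] holds, for each row, one d-dimensional slice per contributor (a token value or a position embedding), i.e. a tensor of shape (n_rows, n_contributors, d); the vector norm over the last dim then scores how strongly each contributor is present in a row (the layer indexing here is an assumption). a sketch:
import torch
# assumed shape: representations[3] is (n_rows, n_contributors, d)
norms = torch.linalg.vector_norm(representations[3][30, :, :], dim=-1)
print(norms.topk(5))  # the dominant contributors to row 30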
we see that layer 5 copies row 35 into 89 because of the high contribution from the token value at position 51 (corresponding to the 17 in the edge 4 -> 17).
layer 4 copies rows 14 and 41 into 35. row 14 has large contribution from the token value at position 51. why is 14 copied into 35? this is because row 14 has (relatively) high contribution from the token value at position 30 (corresponding to the 14 in the edge 1 -> 14), and row 35 has high contribution from the token value at position 30.
why does row 14 have a large contribution from the token value at position 51? and why does row 35 have a large contribution from the token value at position 30? it is unclear why row 35 has a large contribution from the token value at position 30: before the attention layer, the contribution is very small, but attention layer 3 seems to copy from seemingly random source rows. it's possible these other rows are storing information about graph topology that is nonspecific to the original values at those rows. or perhaps these rows are being copied everywhere, and what we care about is the diff (i.e. the rows that are being selectively copied into specific destination rows). -> row 29 corresponds to the source vertex in the edge 1 -> 14, which connects the vertex at row 35 (27) with that at row 14 (6), but the attention matrix is copying row 29 more into rows 14, 17, 35, 41, and 56. how can we detect this?
row 65 (corresponding to the 22 in the edge 22 -> 24) is copied into row 35 because of the large contribution from the token value at position 12 (the 22 in the edge 9 -> 22) in row 35 and the large contribution from the position embedding 176-90=86 (the start vertex special token). why does row 35 have so much information from the end of the path?
row 50 (corresponding to the 4 in the edge 4 -> 17) is copied into row 35 because of large contribution from the token value at position 15 (the 4 in the edge 6 -> 4) in row 35 and the large contribution from the token value at position 50 in row 50. so this is a forward step (!).
row 53 (corresponding to the 19 in the edge 19 -> 8) is copied into row 35 because of large contribution from the token value at position 21 (corresponding to the 19 in the edge 25 -> 19) in row 35 and the large contribution from the token value at position 53 (corresponding to the 19 in the edge 19 -> 8). so this is also a forward step.
row 29 is copied into row 35 due to a forward step.
row 14 has large contribution from the token value at position 51 because attention layer 3 copies from row 50 to row 14. and this copy is due to a backwards step.
layer 2 anti-copies row 29 into row 35 because row 29 has large contribution from the token value at position 29 (the 1 in 1 -> 14) and row 35 has large contribution from the token value at position 36 (the 1 in 26 -> 1), so this is a backwards step that is split over layers 2 and 3 (since layer 3 then copies from many different locations). this is what leads to row 35 having large contribution from the token value at position 30 at the beginning of layer 4.
how do we verify that the contributions are due to position and not the token value? consider the copy in layer 5 from row 35 into 89: this was due to high contribution from the token value at position 51 (the 14 in 1 -> 14). we can change the position of the edge 1 -> 14 and see how this affects the contribution. if we move the edge 1 -> 14 so that the 14 appears in position 48, we see that the copy from row 35 into 89 in layer 5 is due to high contribution from the token value at position 48. we can do this multiple times to get better confidence that this contribution encodes position rather than some other quantity.
-> but why would this contribution encode position and not token value? wouldn't there be identifiability issues if there are vertices of high degree? even if all vertices have degree 2 (in-degree 1 and out-degree 1), then you still can't implement a backwards/forwards step in the attention layer by using the token embeddings alone.
-> maybe it would help to train a linear probe to classify position? but note that each embedding can store multiple positions, especially after one or more backwards/forwards steps, so this becomes more similar to training a probe for reachability.
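a sketch of such a probe (logistic regression via scikit-learn is an assumed choice; X stacks per-vertex embeddings at one layer, y is the binary label of interest, e.g. reachability from the goal within a given distance or a position class; the accuracy/TPR/TNR numbers further below could be read off in the same way):
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
# assumed data: X is a (num_examples, d) numpy array of source-vertex embeddings at one layer,
# y is a (num_examples,) numpy array of binary labels
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=0)
probe = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
pred = probe.predict(X_te)
print((pred == y_te).mean())         # accuracy
print(pred[y_te == 1].mean())        # true positive rate
print(1.0 - pred[y_te == 0].mean())  # true negative rate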
attention layer 4 copies row 14 into 35 due to high contribution from token at position 30 in both rows. so is this just a token matching step? how can we differentiate token matching from backwards/forwards steps?
why does row 14 have high contribution from the token at position 30? before attention layer 3, the rows with high contribution from the token at position 30 are: 11, 32, 47, 53, 65, 68, 71, 88, 8, 14, 20, 23, 29, 74, 26. (most of these are due to forwards steps: rows 11, 26, 32, 47, 53, 68, 71, 88) attention layer 2 copies row 29 (which has high contribution from the token at position 30) into a bunch of rows, including 26, 11, 32, 47, 53, 68, 71, 88, because row 29 has high contribution from the token at position 29 and each destination row x has high contribution from the token at position x+1. so attention layers 2 and 3 collectively copied the contribution from the token at position 30 (from row 29, which corresponds to the edge 1 -> 14) into row 14 (which corresponds to the edge 6 -> 4). so attention layers 2 and 3 are collectively performing a single forwards step.
-> it may become clearer by focusing the analysis on specific contributions, one at a time.
(interestingly, if we change the input into:
27 -> 26 -> 1 -> 14 -> 6 -> 4 -> 17 -> 9 -> 22 -> 24 -> 28
27 -> 12 -> 11 -> 21 -> 18 -> 15 -> 25 -> 19 -> 8 -> 3
27 -> 5 -> 7 -> 23
6 -> 10
17 -> 16
3 -> 10
3 -> 2
and the goal is 28, the model will incorrectly predict 5 (!)
if the goal is 7, the model will incorrectly predict 26 (!))
probing results on source vertices:
(note the overall test accuracy of the network on graphs with lookahead=10 is 99%)
first we check if the model is doing a backward search from the goal vertex: (the probe doesn't include the goal)
layer 2 can decode
reachable_distance=-1 excluding start vertex with 100% accuracy, but not reachable_distance=-2
reachable_distance=-2 gets 96% (45% true positive rate, 100% true negative rate)
layer 3 can decode
reachable_distance=-1 excluding start vertex with 100% accuracy (90% true positive rate, 100% true negative rate)
reachable_distance=-2 excluding start vertex with 99% accuracy (90% true positive rate, 100% true negative rate)
but only gets 95% for reachable_distance=-3 (60% true positive rate, 100% true negative rate)
layer 4,
reachable_distance=-3 gets 97% (78% true positive rate, 100% true negative rate)
reachable_distance=-4 gets 95% (75% true positive rate, 99% true negative rate)
reachable_distance=-5 gets 91% (56% true positive rate, 99% true negative rate)
reachable_distance -8 gets % (% true positive rate, % true negative rate)
layer 5,
reachable_distance -4 gets 95% (74% true positive rate, 100% true negative rate)
reachable_distance -5 gets 93% (73% true positive rate, 99% true negative rate)
reachable_distance -6 gets 90% (65% true positive rate, 98% true negative rate)
reachable_distance -7 gets 87% (60% true positive rate, 98% true negative rate)
reachable_distance -8 gets 84% (54% true positive rate, 97% true negative rate)
reachable_distance -10 gets 75% (50% true positive rate, 93% true negative rate)
layer 6,
reachable_distance -4 gets 95% (70% true positive rate, 99% true negative rate)
reachable_distance -8 gets 83% (53% true positive rate, 97% true negative rate)
check if the model is doing a forward search from the start vertex:
layer 3, reachable_distance=+2 gets 87% (0% true positive rate, 100% true negative rate) (this predictor is just guessing 0 for all inputs) *if the probe doesn't include the start vertex*
if the probe does include the start vertex, it gets 87% accuracy (26% true positive rate, 100% true negative rate) (but this could just be identifying the start vertex and guessing 1 for only that vertex)
layer 4, reachable_distance=+2 gets 93% (42% true positive rate, 100% true negative rate)
if the probe does include the start vertex, it gets 92% accuracy (57% true positive rate, 100% true negative rate)
layer 5, reachable_distance=+2 gets 92% (57% true positive rate, 100% true negative rate)
checking if the model is doing a backward search from vertices at specific absolute positions, or from the kth vertex in the correct path: the probe gets 0% true positive rate and 100% true negative rate.
probing results on source vertices:
layer 2,
reachable_distance=-1 gets 96% (0% true positive rate, 100% true negative rate)
layer 3,
reachable_distance=-1 gets 74% (79% true positive rate, 74% true negative rate)
reachable_distance=-2 gets 91% (32% true positive rate, 96% true negative rate)
what if the model is doing a forward search from the start vertex: (the probe includes the start vertex)
layer 2,
reachable_distance=+1 gets 95% (82% true positive rate, 96% true negative rate)
layer 3,
reachable_distance=+1 gets 98% (59% true positive rate, 100% true negative rate)
reachable_distance=+2 gets 92% (36% true positive rate, 99% true negative rate)
layer 4,
reachable_distance=+2 gets 88% (0% true positive rate, 100% true negative rate)
reachable_distance=+3 gets 80% (2% true positive rate, 99% true negative rate)
layer 5,
reachable_distance=+3 gets 81% (0% true positive rate, 100% true negative rate)
Gameplan for new analysis method:
1. first, for a given input and specific output representation (an n x d matrix, which contains all zeros except for a 1 at the output prediction in the last row), we need to reconstruct the attention copy tree: for each layer, how much does the input contribute to the final output prediction? and where does this contribution come from (from which rows does the previous attention layer copy this contribution)?
2. explain the attention copies
if we have a set of contribution vectors written as rows in the matrix C (which has dimension m x d, where m is the number of contributors), and we have some representation vector r, how do we find the linear combination of the rows of C that equals r? that is, what is the value of a such that a*C = r? but how is this actually useful? for a dot product in the attention mechanism, we compute r_dst * A * r_src^T, where r is the input representation to the attention layer (after layer norm). we can decompose r_dst in terms of the contributions C at the input of this operation, and similarly decompose r_dst * A to see how the A matrix has transformed the representation of r_dst.
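that decomposition step can be set up as a least-squares solve (a sketch; C is the m x d contribution matrix, r the d-dimensional representation to explain):
import torch
def decompose(C, r):
    # find a (length m) minimizing ||a*C - r||, i.e. solve C^T a^T = r^T in the least-squares sense
    a = torch.linalg.lstsq(C.T, r.reshape(-1, 1)).solution
    return a.squeeze(-1)
# applying the same to r_dst @ A shows how the A matrix reweights the contributors of the destination row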
layer 0 copies the target vertex of each edge into the source vertex. it does so by looking at the position encodings: each position that corresponds to the position of a source vertex will copy from the position that is 1 greater.
layer 1 performs a token matching step. all source vertices will copy from themselves.
layer 2 is performing an inverse backwards step, copying the representation of 11 to essentially all other tokens except the token at position 23, which corresponds to the edge 17 -> 9. this operation does not seem to be uniform at all positions (e.g. it does this inverse backwards step at positions 11, 14, and 20, but not 17).
-> but performing the contribution analysis here, we find that the most significant contributors to the negative dot product between 23 and 11 (corresponding to the low attention weight from 11 to 23) come from the token embedding, not the position embedding. how is the token embedding storing information about the position embedding here?
repeating the input from above for convenience:
[31, 31, 31, 31, 31, 31, 31, 30, 7, 23, 30, 9, 22, 30, 6, 4, 30, 6, 10, 30, 25, 19, 30, 17, 9, 30, 17, 16, 30, 1, 14, 30, 11, 21, 30, 26, 1, 30, 12, 11, 30, 14, 6, 30, 15, 25, 30, 4, 17, 30, 24, 28, 30, 19, 8, 30, 27, 26, 30, 27, 12, 30, 27, 5, 30, 22, 24, 30, 8, 3, 30, 18, 15, 30, 3, 7, 30, 3, 10, 30, 3, 2, 30, 21, 18, 32, 27, 28, 29, 27]
27 -> 26 -> 1 -> 14 -> 6 -> 4 -> 17 -> 9 -> 22 -> 24 -> 28
27 -> 12 -> 11 -> 21 -> 18 -> 15 -> 25 -> 19 -> 8 -> 3 -> 7 -> 23
27 -> 5
6 -> 10
17 -> 16
3 -> 10
3 -> 2