From 8feb42768e4f114b2cbf43d7cd92f3c3ee0e2464 Mon Sep 17 00:00:00 2001 From: asaparov Date: Tue, 26 Mar 2024 05:04:32 -0400 Subject: [PATCH] Added code to train linear probes for vertex reachability. Also added code for linear contribution analysis. --- analyze.py | 48 ++++++++++++++++++++----- generator.cpp | 55 ++++++++++++++++------------ notes.txt | 94 ++++++++++++++++++++++++++++++++++++++++++++++-- trace_circuit.py | 47 +++++++++++++++++------- train_probe.py | 65 +++++++++++++++++++++------------ 5 files changed, 240 insertions(+), 69 deletions(-) diff --git a/analyze.py b/analyze.py index 3d91be9..c5bced3 100644 --- a/analyze.py +++ b/analyze.py @@ -10,27 +10,57 @@ def perturb_vertex_ids(input, fix_index, num_examples, max_input_size): EDGE_PREFIX_TOKEN = (max_input_size-5) // 3 + 2 max_vertex_id = (max_input_size-5) // 3 + # compute the correct next vertex + graph = {} + for i in range(len(input)): + if input[i] == EDGE_PREFIX_TOKEN: + if int(input[i+1]) not in graph: + graph[int(input[i+1])] = [int(input[i+2])] + else: + graph[int(input[i+1])].append(int(input[i+2])) + useful_steps = [] + for neighbor in graph[int(input[-1])]: + stack = [neighbor] + reachable = [] + while len(stack) != 0: + current = stack.pop() + reachable.append(current) + if current not in graph: + continue + for child in graph[current]: + if child not in reachable: + stack.append(child) + if int(input[-3]) in reachable: + useful_steps.append(neighbor) + if len(useful_steps) == 0: + raise Exception('Given input has no path to goal vertex.') + elif len(useful_steps) != 1: + raise Exception('Given input has more than one next step to goal vertex.') + out = torch.empty((num_examples, input.shape[0]), dtype=torch.int64) out_labels = torch.empty((num_examples), dtype=torch.int64) out[0,:] = input edge_indices = [i for i in range(len(input)) if input[i] == EDGE_PREFIX_TOKEN] edge_count = len(edge_indices) - fixed_edge_index = next(i for i in range(len(edge_indices)) if fix_index >= edge_indices[i] and fix_index < edge_indices[i] + 3) - fixed_edge = edge_indices[fixed_edge_index] + if fix_index != None: + fixed_edge_index = next(i for i in range(len(edge_indices)) if fix_index >= edge_indices[i] and fix_index < edge_indices[i] + 3) + fixed_edge = edge_indices[fixed_edge_index] padding_size = next(i for i in range(len(input)) if input[i] != PADDING_TOKEN) out[:,:padding_size] = PADDING_TOKEN - out_labels[0] = out[0,fix_index+1] + out_labels[0] = useful_steps[0] for i in range(1, num_examples): id_map = list(range(1, max_vertex_id + 1)) shuffle(id_map) id_map = [0] + id_map - del edge_indices[fixed_edge_index] + if fix_index != None: + del edge_indices[fixed_edge_index] shuffle(edge_indices) - edge_indices.insert(fixed_edge_index, fixed_edge) + if fix_index != None: + edge_indices.insert(fixed_edge_index, fixed_edge) for j in range(len(edge_indices)): out[i,padding_size+(3*j):padding_size+(3*j)+3] = torch.LongTensor([EDGE_PREFIX_TOKEN, id_map[input[edge_indices[j]+1]], id_map[input[edge_indices[j]+2]]]) out[i,padding_size+(3*edge_count):] = torch.LongTensor([(id_map[v] if v <= max_vertex_id else v) for v in input[padding_size+(3*edge_count):]]) - out_labels[i] = out[i,fix_index+1] + out_labels[i] = id_map[useful_steps[0]] return out, out_labels def run_model(model, input, fix_index, max_input_size, num_perturbations=2**14): @@ -49,7 +79,6 @@ def run_model(model, input, fix_index, max_input_size, num_perturbations=2**14): print(padded_input) perturbed_input, perturbed_output = perturb_vertex_ids(padded_input, fix_index, 1+num_perturbations, max_input_size) predictions, _ = model(perturbed_input) - import pdb; pdb.set_trace() if len(predictions.shape) == 3: predictions = predictions[:, -1, :] perturbed_output = perturbed_output.to(device) @@ -217,7 +246,8 @@ def ideal_model(max_input_size, num_layers, hidden_dim, bidirectional, absolute_ #run_model(model, [22, 21, 5, 19, 21, 11, 5, 21, 10, 3, 21, 4, 10, 21, 9, 4, 21, 9, 11, 23, 9, 3, 20, 9], max_input_size=24) #run_model(model, [22, 22, 22, 22, 22, 22, 22, 21, 1, 2, 21, 1, 4, 21, 2, 3, 21, 4, 5, 23, 1, 3, 20, 1], max_input_size=24) #run_model(model, [46, 45, 3, 19, 45, 18, 39, 45, 36, 15, 45, 24, 42, 45, 37, 3, 45, 37, 36, 45, 23, 32, 45, 8, 24, 45, 19, 30, 45, 15, 23, 45, 39, 40, 45, 40, 34, 45, 30, 18, 45, 32, 8, 47, 37, 34, 44, 37], max_input_size=48) - #run_model(model, [43, 15, 34, 43, 30, 9, 43, 14, 22, 43, 8, 13, 43, 8, 2, 43, 26, 1, 43, 1, 14, 43, 36, 7, 43, 22, 4, 43, 22, 2, 43, 34, 26, 43, 34, 25, 43, 28, 30, 43, 16, 3, 43, 16, 32, 43, 13, 33, 43, 12, 15, 43, 25, 21, 43, 9, 36, 43, 3, 12, 43, 32, 8, 43, 33, 28, 45, 16, 4, 42, 16], fix_index=98, max_input_size=max_input_size, num_perturbations=0) + run_model(model, [31, 31, 31, 31, 31, 31, 31, 30, 7, 23, 30, 9, 22, 30, 6, 4, 30, 6, 10, 30, 25, 19, 30, 17, 9, 30, 17, 16, 30, 1, 14, 30, 11, 21, 30, 26, 1, 30, 12, 11, 30, 14, 6, 30, 15, 25, 30, 24, 28, 30, 4, 17, 30, 19, 8, 30, 27, 26, 30, 27, 12, 30, 27, 5, 30, 22, 24, 30, 8, 3, 30, 18, 15, 30, 3, 7, 30, 3, 10, 30, 3, 2, 30, 21, 18, 32, 27, 28, 29, 27], fix_index=None, max_input_size=max_input_size, num_perturbations=1000) + import pdb; pdb.set_trace() seed(training_seed) torch.manual_seed(training_seed) @@ -236,7 +266,7 @@ def get_seed(index): NUM_TEST_SAMPLES = 1000 reserved_inputs = set() print("Generating eval data...") - inputs,outputs = generate_eval_data(max_input_size, min_path_length=2, distance_from_start=1, distance_from_end=-1, lookahead_steps=13, num_paths_at_fork=None, num_samples=NUM_TEST_SAMPLES) + inputs,outputs = generate_eval_data(max_input_size, min_path_length=2, distance_from_start=1, distance_from_end=-1, lookahead_steps=10, num_paths_at_fork=None, num_samples=NUM_TEST_SAMPLES) #generator.set_seed(get_seed(1)) #inputs, outputs, _, _ = generator.generate_training_set(max_input_size, NUM_TEST_SAMPLES, training_max_lookahead, reserved_inputs, 1, False) print("Evaluating model...") diff --git a/generator.cpp b/generator.cpp index e02fff2..8446e90 100644 --- a/generator.cpp +++ b/generator.cpp @@ -575,7 +575,7 @@ py::tuple generate_training_set(const unsigned int max_input_size, const uint64_ return py::make_tuple(inputs, outputs, valid_outputs, num_collisions); } -py::tuple generate_reachable_training_set(const unsigned int max_input_size, const uint64_t dataset_size, const unsigned int lookahead, const py::object& reserved_inputs, const int distance_from_start, const int reachable_distance, const bool start_from_goal) +py::tuple generate_reachable_training_set(const unsigned int max_input_size, const uint64_t dataset_size, const unsigned int lookahead, const py::object& reserved_inputs, const int distance_from_start, const int reachable_distance, const unsigned int start_vertex_index, const bool exclude_start_vertex) { const unsigned int QUERY_PREFIX_TOKEN = (max_input_size-5) / 3 + 4; const unsigned int PADDING_TOKEN = (max_input_size-5) / 3 + 3; @@ -587,11 +587,12 @@ py::tuple generate_reachable_training_set(const unsigned int max_input_size, con size_t input_shape[2]{dataset_size, max_input_size}; size_t output_shape[2]{dataset_size, max_input_size}; py::array_t inputs(input_shape); - py::array_t outputs(output_shape); + py::array_t outputs(output_shape); auto inputs_mem = inputs.mutable_unchecked<2>(); auto outputs_mem = outputs.mutable_unchecked<2>(); py::list valid_outputs; + unsigned int max_vertex_id = (max_input_size - 5) / 3; while (num_generated < dataset_size) { array g(32); node* start; node* end; @@ -606,7 +607,7 @@ py::tuple generate_reachable_training_set(const unsigned int max_input_size, con } unsigned int num_vertices = std::min(lookahead * num_paths + 1 + randrange(0, 6), (max_input_size-5) / 3); - if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, lookahead, num_paths)) { + if (!generate_example(g, start, end, paths, num_vertices, 4, max_vertex_id, true, lookahead, num_paths)) { for (node& n : g) core::free(n); for (array& a : paths) core::free(a); g.length = 0; paths.length = 0; @@ -635,26 +636,6 @@ py::tuple generate_reachable_training_set(const unsigned int max_input_size, con } shuffle(edges); - /* compute the set of reachable vertices */ - array reachable(16); - array> stack(16); - unsigned int start_vertex = start_from_goal ? end->id : start->id; - stack.add(make_pair(start_vertex, 0u)); - while (stack.length != 0) { - pair entry = stack.pop(); - unsigned int current_vertex = entry.key; - unsigned int current_distance = entry.value; - if (!reachable.contains(current_vertex)) - reachable.add(current_vertex); - if (reachable_distance > 0 && current_distance + 1 <= (unsigned int) reachable_distance) { - for (node* child : g[current_vertex].children) - stack.add(make_pair(child->id, current_distance + 1)); - } else if (reachable_distance < 0 && current_distance + 1 <= (unsigned int) -reachable_distance) { - for (node* parent : g[current_vertex].parents) - stack.add(make_pair(parent->id, current_distance + 1)); - } - } - array prefix(max_input_size); for (auto& entry : edges) { prefix[prefix.length++] = EDGE_PREFIX_TOKEN; @@ -681,6 +662,34 @@ py::tuple generate_reachable_training_set(const unsigned int max_input_size, con if (example.length > max_input_size) continue; + /* compute the set of reachable vertices */ + node** vertex_id_map = (node**) calloc(max_vertex_id + 1, sizeof(node*)); + for (unsigned int i = 0; i < g.length; i++) + vertex_id_map[g[i].id] = &g[i]; + array reachable(16); + array> stack(16); + unsigned int start_vertex; + if (example.length < start_vertex_index) + start_vertex = start->id; + else start_vertex = example[example.length - start_vertex_index]; + stack.add(make_pair(start_vertex, 0u)); + while (stack.length != 0) { + pair entry = stack.pop(); + unsigned int current_vertex = entry.key; + unsigned int current_distance = entry.value; + if (!reachable.contains(current_vertex)) + reachable.add(current_vertex); + if (reachable_distance > 0 && current_distance + 1 <= (unsigned int) reachable_distance) { + for (node* child : vertex_id_map[current_vertex]->children) + stack.add(make_pair(child->id, current_distance + 1)); + } else if (reachable_distance < 0 && current_distance + 1 <= (unsigned int) -reachable_distance) { + for (node* parent : vertex_id_map[current_vertex]->parents) + stack.add(make_pair(parent->id, current_distance + 1)); + } + } + if (exclude_start_vertex) + reachable.remove(reachable.index_of(start_vertex)); + array useful_steps(8); for (node* v : path[j-1]->children) if (has_path(v, end)) useful_steps.add(v); diff --git a/notes.txt b/notes.txt index d2d7735..17754b3 100644 --- a/notes.txt +++ b/notes.txt @@ -862,7 +862,40 @@ attention layer 4 copies row 14 into 35 and the negativity of element 40 comes f -interestingly, if we change the input into: +repeating the input from above for convenience: +[31, 31, 31, 31, 31, 31, 31, 30, 7, 23, 30, 9, 22, 30, 6, 4, 30, 6, 10, 30, 25, 19, 30, 17, 9, 30, 17, 16, 30, 1, 14, 30, 11, 21, 30, 26, 1, 30, 12, 11, 30, 14, 6, 30, 15, 25, 30, 24, 28, 30, 4, 17, 30, 19, 8, 30, 27, 26, 30, 27, 12, 30, 27, 5, 30, 22, 24, 30, 8, 3, 30, 18, 15, 30, 3, 7, 30, 3, 10, 30, 3, 2, 30, 21, 18, 32, 27, 28, 29, 27] + +27 -> 26 -> 1 -> 14 -> 6 -> 4 -> 17 -> 9 -> 22 -> 24 -> 28 +27 -> 12 -> 11 -> 21 -> 18 -> 15 -> 25 -> 19 -> 8 -> 3 -> 7 -> 23 +27 -> 5 + 6 -> 10 +17 -> 16 + 3 -> 10 + 3 -> 2 + +using new contribution representation analysis, using the following to compute the magnitude of the contribution from various inputs: + torch.linalg.vector_norm(representations[3][30,:,:], dim=1) +we see that layer 5 copies row 35 into 89 because of the high contribution from the token value at position 51 (corresponding to the 17 in the edge 4 -> 17). + +layer 4 copies rows 14 and 41 into 35. row 14 has large contribution from the token value at position 51. why is 14 copied into 35? this is because row 14 has (relatively) high contribution from the token value at position 30 (corresponding to the 14 in the edge 1 -> 14), and row 35 has high contribution from the token value at position 30. + +why does row 14 have large contribution from the token value at position 51? and why does row 35 has large contribution from the token value at position 30? it is unclear why row 35 has large contribution from token value at position 30, before the attention layer, the contribution is very small, but attention layer 3 seems to copy from seemingly random source rows. it's possible these other rows are storing information about graph topology that is nonspecific to the original values at those rows. or perhaps these rows are being copied everywhere, and what we care about is the diff (i.e. the rows that are being selectively copied into specific destination rows). -> row 29 corresponds to the source vertex in the edge 1 -> 14, which connects the vertex at row 35 (27) with that at row 14 (6), but the attention matrix is copying row 29 more into rows 14, 17, 35, 41, and 56. how can we detect this? + row 65 (correponding to the 22 in the edge 22 -> 24) is copied into row 35 because of large contribution from the token value at position 12 (the 22 in the edge 9 -> 22) in row 35 and the large contribution from the position embedding 176-90=86 (the start vertex special token). why does row 35 have so much information from the end of the path? + + row 50 (corresponding to the 4 in the edge 4 -> 17) is copied into row 35 because of large contribution from the token value at position 15 (the 4 in the edge 6 -> 4) in row 35 and the large contribution from the token value at position 50 in row 50. so this is a forward step (!). + row 53 (corresponding to the 19 in the edge 19 -> 8) is copied into row 35 because of large contribution from the token value at position 21 (corresponding to the 19 in the edge 25 -> 19) in row 35 and the large contribution from the token value at position 53 (corresponding to the 19 in the edge 19 -> 8). so this is also a forward step. + row 29 is copied into row 35 due to a forward step. +row 14 has large contribution from the token value at position 51 because attention layer 3 copies from row 50 to row 14. and this copy is due to a backwards step. + +layer 2 anti-copies row 29 into row 35 because row 29 has large contribution from the token value at position 29 (the 1 in 1 -> 14) and row 35 has large contribution from the token value at position 36 (the 1 in 26 -> 1), so this is a backwards step that is split over layers 2 and 3 (since layer 3 then copies from many different locations). this is what leads to row 35 having large contribution from the token value at position 30 at the beginning of layer 4. + +how do we verify that the contributions are due to position and not the token value? consider the copy in layer 5 from row 35 into 89: this was due to high contribution from the token value at position 51 (the 14 in 1 -> 14). we can change the position of the edge 1 -> 14 and see how this affects the contribution. if we move the edge 1 -> 14 so that the 14 appears in position 48, we see that the copy from row 35 into 89 in layer 5 is due to high contribution from the token value at position 48. we can do this multiple times to get better confidence that this contribution encodes position rather than some other quantity. + -> but why would this contribution encode position and not token value? wouldn't there be identifiability issues if there are vertices of high degree? even if all vertices have degree 2 (in-degree 1 and out-degree 1), then you still can't implement a backwards/forwards step in the attention layer by using the token embeddings alone. + -> maybe it would help to train a linear probe to classify position? but note that each embedding can store multiple positions, especially after one or more backwards/forwards steps, so this becomes more similar to training a probe for reachability. + + + +(interestingly, if we change the input into: 27 -> 26 -> 1 -> 14 -> 6 -> 4 -> 17 -> 9 -> 22 -> 24 -> 28 27 -> 12 -> 11 -> 21 -> 18 -> 15 -> 25 -> 19 -> 8 -> 3 27 -> 5 -> 7 -> 23 @@ -870,6 +903,61 @@ interestingly, if we change the input into: 17 -> 16 3 -> 10 3 -> 2 - and the goal is 28, the model will incorrectly predict 5 (!) -if the goal is 7, the model will incorrectly predict 26 (!) +if the goal is 7, the model will incorrectly predict 26 (!)) + + + +probing results on source vertices: +(note the overall test accuracy of the network on graphs with lookahead=10 is 99%) +first we check if the model is doing a backward search from the goal vertex: (the probe doesn't include the goal) +layer 2 can decode + reachable_distance=-1 excluding start vertex with 100% accuracy, but not reachable_distance=-2 + reachable_distance=-2 gets 96% (45% true positive rate, 100% true negative rate) +layer 3 can decode + reachable_distance=-1 excluding start vertex with 100% accuracy (90% true positive rate, 100% true negative rate) + reachable_distance=-2 excluding start vertex with 99% accuracy (90% true positive rate, 100% true negative rate) + but only gets 95% for reachable_distance=-3 (60% true positive rate, 100% true negative rate) +layer 4, + reachable_distance=-3 gets 97% (78% true positive rate, 100% true negative rate) + reachable_distance=-4 gets 95% (75% true positive rate, 99% true negative rate) + reachable_distance=-5 gets 91% (56% true positive rate, 99% true negative rate) + reachable_distance -8 gets % (% true positive rate, % true negative rate) +layer 5, + reachable_distance -4 gets 95% (74% true positive rate, 100% true negative rate) + reachable_distance -5 gets 93% (73% true positive rate, 99% true negative rate) + reachable_distance -6 gets 90% (65% true positive rate, 98% true negative rate) + reachable_distance -7 gets 87% (60% true positive rate, 98% true negative rate) + reachable_distance -8 gets 84% (54% true positive rate, 97% true negative rate) + reachable_distance -10 gets 75% (50% true positive rate, 93% true negative rate) +layer 6, + reachable_distance -4 gets 95% (70% true positive rate, 99% true negative rate) + reachable_distance -8 gets 83% (53% true positive rate, 97% true negative rate) + +check if the model is doing a forward search from the start vertex: +layer 3, reachable_distance=+2 gets 87% (0% true positive rate, 100% true negative rate) (this predictor is just guessing 0 for all inputs) *if the probe doesn't include the start vertex* + if the probe does include the start vertex, it gets 87% accuracy (26% true positive rate, 100% true negative rate) (but this could just be identifying the start vertex and guessing 1 for only that vertex) +layer 4, reachable_distance=+2 gets 93% (42% true positive rate, 100% true negative rate) + if the probe does include the start vertex, it gets 92% accuracy (57% true positive rate, 100% true negative rate) +layer 5, reachable_distance=+2 gets 92% (57% true positive rate, 100% true negative rate) + +check if the model is doing a backward search from vertices at specific absolute positions, or from the kth vertex in the correct path, the probe gets 0% true positive rate and 100% true negative rate. + +probing results on source vertices: +layer 2, + reachable_distance=-1 gets 96% (0% true positive rate, 100% true negative rate) +layer 3, + reachable_distance=-1 gets 74% (79% true positive rate, 74% true negative rate) + reachable_distance=-2 gets 91% (32% true positive rate, 96% true negative rate) + +what if the model is doing a forward search from the start vertex: (the probe includes the start vertex) +layer 2, + reachable_distance=+1 gets 95% (82% true positive rate, 96% true negative rate) +layer 3, + reachable_distance=+1 gets 98% (59% true positive rate, 100% true negative rate) + reachable_distance=+2 gets 92% (36% true positive rate, 99% true negative rate) +layer 4, + reachable_distance=+2 gets 88% (0% true positive rate, 100% true negative rate) + reachable_distance=+3 gets 80% (2% true positive rate, 99% true negative rate) +layer 5, + reachable_distance=+3 gets 81% (0% true positive rate, 100% true negative rate) diff --git a/trace_circuit.py b/trace_circuit.py index 08132de..a85e548 100644 --- a/trace_circuit.py +++ b/trace_circuit.py @@ -520,30 +520,51 @@ def trace_activation(i, row, representations): def trace_activation_forward(representation, num_layers): representation = representation.clone().detach() representations = [representation.clone().detach()] + ff_representations = [] attn_representations = [] + bias_contribution = torch.zeros(representation[0].shape) for i in range(num_layers): layer_norm_matrix = self.model.transformers[i].ln_attn.weight.unsqueeze(0).repeat((n,1)) / torch.sqrt(torch.var(layer_inputs[i], dim=1, correction=0) + self.model.transformers[i].ln_attn.eps).unsqueeze(1).repeat((1,d)) + layer_norm_bias = -torch.mean(layer_inputs[i], dim=1).unsqueeze(1) * layer_norm_matrix + self.model.transformers[i].ln_attn.bias attn_representation = representation * layer_norm_matrix - attn_representations.append(attn_representation) + attn_representations.append((attn_representation, bias_contribution * layer_norm_matrix + layer_norm_bias)) representation += torch.matmul(torch.matmul(torch.matmul(attn_matrices[i], attn_representation), self.model.transformers[i].attn.proj_v.weight.T), self.model.transformers[i].attn.linear.weight.T) + # compute the contribution from bias terms in the attention layer + bias_contribution += torch.matmul(torch.matmul(torch.matmul(attn_matrices[i], bias_contribution * layer_norm_matrix), self.model.transformers[i].attn.proj_v.weight.T), self.model.transformers[i].attn.linear.weight.T) + bias_contribution += torch.matmul(torch.matmul(torch.matmul(attn_matrices[i], layer_norm_bias), self.model.transformers[i].attn.proj_v.weight.T), self.model.transformers[i].attn.linear.weight.T) + bias_contribution += torch.matmul(self.model.transformers[i].attn.proj_v.bias.unsqueeze(0), self.model.transformers[i].attn.linear.weight.T).repeat((n,1)) + bias_contribution += self.model.transformers[i].attn.linear.bias.unsqueeze(0).repeat((n,1)) + + ff_representations.append(representation.clone().detach()) + if self.model.transformers[i].ff: layer_norm_matrix = self.model.transformers[i].ln_ff.weight.unsqueeze(0).repeat((n,1)) / torch.sqrt(torch.var(ff_inputs[i], dim=1, correction=0) + self.model.transformers[i].ln_ff.eps).unsqueeze(1).repeat((1,d)) + layer_norm_bias = -torch.mean(ff_inputs[i], dim=1).unsqueeze(1) * layer_norm_matrix + self.model.transformers[i].ln_ff.bias ff0_output = self.model.transformers[i].ff[0](self.model.transformers[i].ln_ff(ff_inputs[i])) representation += torch.matmul((ff0_output > 0.0) * torch.matmul(representation * layer_norm_matrix, self.model.transformers[i].ff[0].weight.T), self.model.transformers[i].ff[3].weight.T) + + # compute the contribution from the bias terms in the FF layer + layer_norm_bias = -torch.mean(ff_inputs[i], dim=1).unsqueeze(1) * layer_norm_matrix + self.model.transformers[i].ln_ff.bias + bias_contribution += torch.matmul((ff0_output > 0.0) * torch.matmul(bias_contribution * layer_norm_matrix, self.model.transformers[i].ff[0].weight.T), self.model.transformers[i].ff[3].weight.T) + bias_contribution += torch.matmul((ff0_output > 0.0) * torch.matmul(layer_norm_bias, self.model.transformers[i].ff[0].weight.T), self.model.transformers[i].ff[3].weight.T) + bias_contribution += torch.matmul((ff0_output > 0.0) * self.model.transformers[i].ff[0].bias.unsqueeze(0), self.model.transformers[i].ff[3].weight.T) + bias_contribution += self.model.transformers[i].ff[3].bias.unsqueeze(0).repeat((n,1)) representations.append(representation.clone().detach()) - return attn_representations, representations + return attn_representations, ff_representations, representations def check_copyr(i, dst, src, attn_inputs, attn_matrices, attn_representations): attn_input = attn_inputs[i] + attn_representation, attn_bias_contribution = attn_representations[i] if not quiet: print('Attention layer {} is copying row {} into row {} with weight {} because:'.format(i,src,dst,attn_matrices[i][dst,src])) A = torch.matmul(self.model.transformers[i].attn.proj_q.weight.T, self.model.transformers[i].attn.proj_k.weight) - left = torch.matmul(attn_representations[i][:,dst,:], A) - products = torch.matmul(left, attn_representations[i][:,src,:].T) - #left = torch.matmul(torch.cat((attn_representations[i][:,dst,:], torch.ones((2*n,1))), 1), A_matrices[i]) - #products = torch.matmul(left, torch.cat((attn_representations[i][:,src,:], torch.ones((2*n,1))), 1).T) + left = torch.matmul(attn_representation[:,dst,:], A) + products = torch.matmul(left, attn_representation[:,src,:].T) + bias_product = torch.dot(attn_bias_contribution[dst,:],attn_bias_contribution[src,:]) + #left = torch.matmul(torch.cat((attn_representation[:,dst,:], torch.ones((2*n,1))), 1), A_matrices[i]) + #products = torch.matmul(left, torch.cat((attn_representation[:,src,:], torch.ones((2*n,1))), 1).T) import pdb; pdb.set_trace() #trace_activation(5, 35, [(40, attn_inputs[5][35,40]), (125, attn_inputs[5][35,125]), (128, attn_inputs[5][35,128])]) @@ -551,8 +572,10 @@ def check_copyr(i, dst, src, attn_inputs, attn_matrices, attn_representations): for i in range(n): representation[i,i,input[i]] = 1.0 representation[n+i,i,d-n+i] = 1.0 - attn_representations, representations = trace_activation_forward(representation, len(self.model.transformers)) - check_copyr(4, 35, 41, attn_inputs, attn_matrices, attn_representations) + attn_representations, ff_representations, representations = trace_activation_forward(representation, len(self.model.transformers)) + check_copyr(5, 89, 35, attn_inputs, attn_matrices, attn_representations) + + other_layer_inputs, other_attn_inputs, other_attn_pre_softmax, other_attn_matrices, other_v_outputs, other_attn_linear_inputs, other_attn_outputs, other_A_matrices, other_ff_inputs, other_ff_parameters, other_prediction = self.forward(x, mask, 0, False, [(4, 35, layer_inputs[4][35,:] - representations[4][30,35,:] + representations[3][30,35,:])]) import pdb; pdb.set_trace() PADDING_TOKEN = (n - 5) // 3 + 3 @@ -805,12 +828,12 @@ def check_copy(i, row, j): #input = [46, 46, 46, 46, 46, 46, 46, 45, 31, 39, 45, 42, 4, 45, 21, 7, 45, 19, 20, 45, 13, 22, 45, 7, 42, 45, 20, 21, 45, 17, 19, 45, 17, 31, 45, 10, 14, 45, 39, 10, 45, 14, 13, 47, 17, 4, 44, 17] #input = [62, 62, 62, 62, 62, 61, 15, 8, 61, 11, 18, 61, 9, 5, 61, 19, 14, 61, 19, 17, 61, 1, 11, 61, 6, 7, 61, 10, 3, 61, 2, 1, 61, 13, 10, 61, 12, 4, 61, 17, 16, 61, 7, 12, 61, 14, 2, 61, 3, 9, 61, 16, 15, 61, 18, 6, 61, 8, 13, 63, 19, 4, 60, 19] #input = [44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 43, 21, 40, 43, 21, 22, 43, 22, 34, 43, 13, 3, 43, 31, 2, 43, 24, 13, 43, 4, 41, 43, 30, 31, 43, 17, 15, 43, 34, 38, 43, 3, 28, 43, 18, 17, 43, 14, 24, 43, 2, 4, 43, 32, 30, 43, 40, 18, 43, 15, 14, 43, 38, 32, 45, 21, 28, 42, 21] - input = [31, 31, 31, 31, 31, 31, 31, 30, 7, 23, 30, 9, 22, 30, 6, 4, 30, 6, 10, 30, 25, 19, 30, 17, 9, 30, 17, 16, 30, 1, 14, 30, 11, 21, 30, 26, 1, 30, 12, 11, 30, 14, 6, 30, 15, 25, 30, 24, 28, 30, 4, 17, 30, 19, 8, 30, 27, 26, 30, 27, 12, 30, 27, 5, 30, 22, 24, 30, 8, 3, 30, 18, 15, 30, 3, 7, 30, 3, 10, 30, 3, 2, 30, 21, 18, 32, 27, 28, 29, 27] + input = [31, 31, 31, 31, 31, 31, 31, 30, 7, 23, 30, 9, 22, 30, 6, 4, 30, 6, 10, 30, 25, 19, 30, 17, 9, 30, 17, 16, 30, 1, 14, 30, 11, 21, 30, 26, 1, 30, 12, 11, 30, 14, 6, 30, 15, 25, 30, 4, 17, 30, 24, 28, 30, 19, 8, 30, 27, 26, 30, 27, 12, 30, 27, 5, 30, 22, 24, 30, 8, 3, 30, 18, 15, 30, 3, 7, 30, 3, 10, 30, 3, 2, 30, 21, 18, 32, 27, 28, 29, 27] input = torch.LongTensor(input).to(device) other_input = input.clone().detach() - #other_input[-3] = 7 - other_input[next(i for i in range(len(input)) if input[i] == 1 and input[i+1] == 14) + 1] = 21 - other_input[next(i for i in range(len(input)) if input[i] == 11 and input[i+1] == 21) + 1] = 14 + other_input[-3] = 7 + #other_input[next(i for i in range(len(input)) if input[i] == 1 and input[i+1] == 14) + 1] = 21 + #other_input[next(i for i in range(len(input)) if input[i] == 11 and input[i+1] == 21) + 1] = 14 #other_input[next(i for i in range(len(input)) if input[i] == 6 and input[i+1] == 4) + 0] = 18 #other_input[next(i for i in range(len(input)) if input[i] == 6 and input[i+1] == 4) + 1] = 15 #other_input[next(i for i in range(len(input)) if input[i] == 18 and input[i+1] == 15) + 0] = 6 diff --git a/train_probe.py b/train_probe.py index 4dcf552..4e6f2a0 100644 --- a/train_probe.py +++ b/train_probe.py @@ -1,10 +1,12 @@ -from random import seed, randrange, Random +from random import seed, randrange, getstate, Random import numpy as np import torch from torch import nn, LongTensor, FloatTensor from torch.nn import BCEWithLogitsLoss, Sigmoid from torch.utils.data import DataLoader from Sophia import SophiaG +from sys import stdout +from train import binomial_confidence_int import multiprocessing def build_module(name): @@ -41,7 +43,10 @@ def __init__(self, tfm_model, probe_layer): for param in tfm_model.parameters(): param.requires_grad = False self.probe_layer = probe_layer - self.decoder = nn.Linear(hidden_dim, hidden_dim) + n = tfm_model.positional_embedding.size(0) + self.decoder = nn.Linear(hidden_dim + n * hidden_dim, 1) + if probe_layer > len(tfm_model.transformers): + raise Exception('probe_layer must be <= number of layers') def to(self, device): super().to(device) @@ -64,6 +69,7 @@ def forward(self, x: torch.Tensor): pos = self.model.positional_embedding.unsqueeze(0).expand(x.shape[0], -1, -1) x = torch.cat((x, pos), -1) x = self.model.dropout_embedding(x) + input = x.clone() '''print("embedded input:") for i in range(x.size(0)): @@ -80,7 +86,8 @@ def forward(self, x: torch.Tensor): if i + 1 == self.probe_layer: break - return self.decoder(x) + dec_input = torch.cat((x, input.reshape(input.size(0), input.size(1) * input.size(2)).unsqueeze(1).repeat((1,input.size(1),1))), dim=2) + return self.decoder(dec_input) def evaluate_decoder(model, max_input_size): @@ -156,7 +163,10 @@ def evaluate_decoder(model, max_input_size): device = torch.device('cuda') tfm_model, _, _, _ = torch.load(argv[1], map_location=device) - model = TransformerProber(tfm_model, probe_layer=1) + for transformer in tfm_model.transformers: + if not hasattr(transformer, 'pre_ln'): + transformer.pre_ln = True + model = TransformerProber(tfm_model, probe_layer=4) model.to(device) suffix = argv[1][argv[1].index('inputsize')+len('inputsize'):] @@ -164,14 +174,15 @@ def evaluate_decoder(model, max_input_size): # reserve some data for validation lookahead_steps = 10 - reachable_distance = 4 - start_vertex_index = max_input_size-1 + reachable_distance = 8 + start_vertex_index = 3 + exclude_start_vertex = True reserved_inputs = set() # we are doing streaming training, so use an IterableDataset from itertools import cycle from threading import Lock - STREAMING_BLOCK_SIZE = 2 ** 18 + STREAMING_BLOCK_SIZE = 2 ** 13 NUM_DATA_WORKERS = 2 seed_generator = Random(seed_value) seed_generator_lock = Lock() @@ -206,8 +217,7 @@ def process_data(self, start): torch.manual_seed(new_seed) np.random.seed(new_seed) - inputs, outputs, num_collisions = generator.generate_reachable_training_set(max_input_size, BATCH_SIZE, lookahead_steps, reserved_inputs, 1, reachable_distance, start_vertex_index) - import pdb; pdb.set_trace() + inputs, outputs, num_collisions = generator.generate_reachable_training_set(max_input_size, BATCH_SIZE, lookahead_steps, reserved_inputs, 1, reachable_distance, start_vertex_index, exclude_start_vertex) if num_collisions != 0: with self.collisions_lock: self.total_collisions.value += num_collisions @@ -223,17 +233,19 @@ def __iter__(self): return self.process_data(self.offset + worker_id) epoch = 0 - BATCH_SIZE = 2 ** 11 + BATCH_SIZE = 2 ** 9 iterable_dataset = StreamingDataset(epoch * STREAMING_BLOCK_SIZE // BATCH_SIZE) train_loader = DataLoader(iterable_dataset, batch_size=None, num_workers=NUM_DATA_WORKERS, pin_memory=True, prefetch_factor=8) loss_func = BCEWithLogitsLoss(reduction='mean') - optimizer = SophiaG((p for p in model.parameters() if p.requires_grad), lr=1.0e-3) + optimizer = SophiaG((p for p in model.parameters() if p.requires_grad), lr=1.0e-4) log_interval = 1 - eval_interval = 20 - save_interval = 100 + eval_interval = 1 + save_interval = 10 + epoch_loss = 0.0 + num_batches = 0 while True: for batch in train_loader: model.train() @@ -245,9 +257,8 @@ def __iter__(self): logits = model(input) # only take the predictions on source vertices - import pdb; pdb.set_trace() - logits = logits[:,range(2,max_input_size-5,3),:] - output = output[:,range(2,max_input_size-5,3),:] + logits = logits[:,range(2,max_input_size-5,3),-1] + output = output[:,range(2,max_input_size-5,3)] loss_val = loss_func(logits, output) epoch_loss += loss_val.item() @@ -268,14 +279,24 @@ def __iter__(self): if epoch % eval_interval == 0: model.eval() - logits, _ = model(input) - training_acc = torch.sum(torch.argmax(logits[:,-1,:],dim=1) == output).item() / output.size(0) - print("training accuracy: %.2f±%.2f" % (training_acc, binomial_confidence_int(training_acc, output.size(0)))) - del input, output + logits = model(input) + training_preds = (logits[:,range(2,max_input_size-5,3),-1] > 0.5) + training_labels = (output == 1) + total_preds = training_preds.size(0) * training_preds.size(1) + training_acc = torch.sum(training_preds == training_labels).item() / total_preds + try: + print("training accuracy: %.2f±%.2f" % (training_acc, binomial_confidence_int(training_acc, total_preds))) + true_positive_acc = torch.sum(training_preds[output == 1]).item() / torch.sum(output == 1).item() + print("true positive rate: %.2f±%.2f" % (true_positive_acc, binomial_confidence_int(true_positive_acc, torch.sum(output == 1).item()))) + true_negative_acc = torch.sum(training_preds[output == 0] == 0).item() / torch.sum(output == 0).item() + print("true negative rate: %.2f±%.2f" % (true_negative_acc, binomial_confidence_int(true_negative_acc, torch.sum(output == 0).item()))) + del input, output + except ZeroDivisionError: + pass stdout.flush() - test_acc,test_loss,_ = evaluate_model(model, eval_inputs, eval_outputs) - print("test accuracy = %.2f±%.2f, test loss = %f" % (test_acc, binomial_confidence_int(test_acc, 1000), test_loss)) + #test_acc,test_loss,_ = evaluate_model(model, eval_inputs, eval_outputs) + #print("test accuracy = %.2f±%.2f, test loss = %f" % (test_acc, binomial_confidence_int(test_acc, 1000), test_loss)) stdout.flush() epoch += 1