Skip to content

Commit

Permalink
Modifying the linear probe training code.
Browse files Browse the repository at this point in the history
  • Loading branch information
asaparov committed Mar 12, 2024
1 parent defa724 commit df33c9b
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 88 deletions.
146 changes: 146 additions & 0 deletions generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,153 @@ py::tuple generate_training_set(const unsigned int max_input_size, const uint64_
return py::make_tuple(inputs, outputs, valid_outputs, num_collisions);
}

py::tuple generate_reachable_training_set(const unsigned int max_input_size, const uint64_t dataset_size, const unsigned int lookahead, const py::object& reserved_inputs, const int distance_from_start, const int reachable_distance, const bool start_from_goal)
{
const unsigned int QUERY_PREFIX_TOKEN = (max_input_size-5) / 3 + 4;
const unsigned int PADDING_TOKEN = (max_input_size-5) / 3 + 3;
const unsigned int EDGE_PREFIX_TOKEN = (max_input_size-5) / 3 + 2;
const unsigned int PATH_PREFIX_TOKEN = (max_input_size-5) / 3 + 1;

unsigned int num_generated = 0;
unsigned int num_collisions = 0;
size_t input_shape[2]{dataset_size, max_input_size};
size_t output_shape[2]{dataset_size, max_input_size};
py::array_t<int64_t, py::array::c_style> inputs(input_shape);
py::array_t<int64_t, py::array::c_style> outputs(output_shape);
auto inputs_mem = inputs.mutable_unchecked<2>();
auto outputs_mem = outputs.mutable_unchecked<2>();
py::list valid_outputs;

while (num_generated < dataset_size) {
array<node> g(32);
node* start; node* end;
array<array<node*>> paths(8);
while (true) {
unsigned int num_paths;
if (lookahead == 0) {
num_paths = randrange(1, 3);
} else {
unsigned int max_num_paths = ((max_input_size - 5) / 3 - 1) / lookahead;
num_paths = randrange(2, max_num_paths + 1);
}

unsigned int num_vertices = std::min(lookahead * num_paths + 1 + randrange(0, 6), (max_input_size-5) / 3);
if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, lookahead, num_paths)) {
for (node& n : g) core::free(n);
for (array<node*>& a : paths) core::free(a);
g.length = 0; paths.length = 0;
continue;
}
unsigned int shortest_path_length = paths[0].length;
for (unsigned int i = 1; i < paths.length; i++)
if (paths[i].length < shortest_path_length)
shortest_path_length = paths[i].length;
if (shortest_path_length > 1)
break;
for (node& n : g) core::free(n);
for (array<node*>& a : paths) core::free(a);
g.length = 0; paths.length = 0;
}

array<pair<unsigned int, unsigned int>> edges(8);
for (node& vertex : g)
for (node* child : vertex.children)
edges.add(make_pair(vertex.id, child->id));
if (edges.length * 3 + 4 > max_input_size) {
for (node& n : g) core::free(n);
for (array<node*>& a : paths) core::free(a);
g.length = 0; paths.length = 0;
continue;
}
shuffle(edges);

/* compute the set of reachable vertices */
array<unsigned int> reachable(16);
array<pair<unsigned int, unsigned int>> stack(16);
unsigned int start_vertex = start_from_goal ? end->id : start->id;
stack.add(make_pair(start_vertex, 0u));
while (stack.length != 0) {
pair<unsigned int, unsigned int> entry = stack.pop();
unsigned int current_vertex = entry.key;
unsigned int current_distance = entry.value;
if (!reachable.contains(current_vertex))
reachable.add(current_vertex);
if (reachable_distance > 0 && current_distance + 1 <= (unsigned int) reachable_distance) {
for (node* child : g[current_vertex].children)
stack.add(make_pair(child->id, current_distance + 1));
} else if (reachable_distance < 0 && current_distance + 1 <= (unsigned int) -reachable_distance) {
for (node* parent : g[current_vertex].parents)
stack.add(make_pair(parent->id, current_distance + 1));
}
}

array<unsigned int> prefix(max_input_size);
for (auto& entry : edges) {
prefix[prefix.length++] = EDGE_PREFIX_TOKEN;
prefix[prefix.length++] = entry.key;
prefix[prefix.length++] = entry.value;
}
prefix[prefix.length++] = QUERY_PREFIX_TOKEN;
prefix[prefix.length++] = start->id;
prefix[prefix.length++] = end->id;
prefix[prefix.length++] = PATH_PREFIX_TOKEN;

for (const array<node*>& path : paths) {
if (path.length == 1)
continue;
for (unsigned int j = 1; j < path.length; j++) {
if (distance_from_start != -1 && j != (unsigned int) distance_from_start)
continue;
array<unsigned int> example(prefix.length + j);
for (unsigned int i = 0; i < prefix.length; i++)
example[i] = prefix[i];
for (unsigned int i = 0; i < j; i++)
example[prefix.length + i] = path[i]->id;
example.length = prefix.length + j;
if (example.length > max_input_size)
continue;

array<node*> useful_steps(8);
for (node* v : path[j-1]->children)
if (has_path(v, end)) useful_steps.add(v);

/* check if this input is reserved */
py::object contains = reserved_inputs.attr("__contains__");
py::tuple example_tuple(example.length);
for (unsigned int i = 0; i < example.length; i++)
example_tuple[i] = example[i];
if (contains(example_tuple).is(py_true)) {
num_collisions += 1;
continue;
}

for (unsigned int i = 0; i < max_input_size - example.length; i++)
inputs_mem(num_generated, i) = PADDING_TOKEN;
for (unsigned int i = 0; i < example.length; i++)
inputs_mem(num_generated, max_input_size - example.length + i) = example[i];
for (unsigned int i = 0; i < max_input_size - example.length; i++)
outputs_mem(num_generated, i) = 0;
for (unsigned int i = 0; i < example.length; i++)
outputs_mem(num_generated, max_input_size - example.length + i) = reachable.contains(example[i]) ? 1 : 0;
num_generated++;
if (num_generated == dataset_size)
break;
}
if (num_generated == dataset_size)
break;
}

for (node& n : g) core::free(n);
for (array<node*>& a : paths) core::free(a);
g.length = 0; paths.length = 0;
continue;
}

return py::make_tuple(inputs, outputs, num_collisions);
}

PYBIND11_MODULE(generator, m) {
m.def("generate_training_set", &generate_training_set);
m.def("generate_reachable_training_set", &generate_reachable_training_set);
m.def("set_seed", &core::set_seed);
}
7 changes: 4 additions & 3 deletions trace_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,8 @@ def trace_activation(i, row, representations):
print(' Row {}, element {} has activation {}'.format(src_row, element, activation))

def trace_activation_forward(representation, num_layers):
representations = [representation]
representation = representation.clone().detach()
representations = [representation.clone().detach()]
attn_representations = []
for i in range(num_layers):
layer_norm_matrix = self.model.transformers[i].ln_attn.weight.unsqueeze(0).repeat((n,1)) / torch.sqrt(torch.var(layer_inputs[i], dim=1, correction=0) + self.model.transformers[i].ln_attn.eps).unsqueeze(1).repeat((1,d))
Expand All @@ -530,7 +531,7 @@ def trace_activation_forward(representation, num_layers):
layer_norm_matrix = self.model.transformers[i].ln_ff.weight.unsqueeze(0).repeat((n,1)) / torch.sqrt(torch.var(ff_inputs[i], dim=1, correction=0) + self.model.transformers[i].ln_ff.eps).unsqueeze(1).repeat((1,d))
ff0_output = self.model.transformers[i].ff[0](self.model.transformers[i].ln_ff(ff_inputs[i]))
representation += torch.matmul((ff0_output > 0.0) * torch.matmul(representation * layer_norm_matrix, self.model.transformers[i].ff[0].weight.T), self.model.transformers[i].ff[3].weight.T)
representations.append(representation)
representations.append(representation.clone().detach())

return attn_representations, representations

Expand All @@ -551,7 +552,7 @@ def check_copyr(i, dst, src, attn_inputs, attn_matrices, attn_representations):
representation[i,i,input[i]] = 1.0
representation[n+i,i,d-n+i] = 1.0
attn_representations, representations = trace_activation_forward(representation, len(self.model.transformers))
check_copyr(2, 65, 47, attn_inputs, attn_matrices, attn_representations)
check_copyr(4, 35, 41, attn_inputs, attn_matrices, attn_representations)
import pdb; pdb.set_trace()

PADDING_TOKEN = (n - 5) // 3 + 3
Expand Down
Loading

0 comments on commit df33c9b

Please sign in to comment.