
Commit

Added options to train on examples from the star graph distribution, as well as an ablated form of the balanced distribution.
Abulhair Saparov committed Nov 19, 2024
1 parent 049afe8 commit 8a181f7
Showing 3 changed files with 159 additions and 58 deletions.
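The generator.cpp changes below add a max_prefix_vertices parameter to generate_graph_with_lookahead, generate_example, and generate_training_set; existing call sites pass -1 as an "unlimited" sentinel, and capping it is plausibly the ablated form of the balanced distribution mentioned in the commit message. A hypothetical Python sketch of calling the extended binding (the example values, the meaning of -1 for distance_from_start, and the 4-tuple return shape are assumptions, mirroring the generate_dfs_training_set and generate_si_training_set calls in analyze.py):

import generator  # pybind11 module built from generator.cpp (name as used in analyze.py)

max_input_size = 98   # example values only
dataset_size = 100000
max_lookahead = 12
max_edges = 16
reserved_inputs = set()

# -1 for max_prefix_vertices maps to max_input_size inside generate_training_set
# (see the call sites below), i.e. the previous unrestricted behavior; a small
# non-negative value caps the number of prefix vertices prepended to the graph.
inputs, outputs, labels, _ = generator.generate_training_set(
    max_input_size, dataset_size, max_lookahead, max_edges,
    reserved_inputs,
    -1,     # distance_from_start (assumed: -1 = unrestricted)
    -1,     # max_prefix_vertices (-1 = no cap)
    False)  # quiet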
45 changes: 42 additions & 3 deletions analyze.py
@@ -304,7 +304,8 @@ def print_graph(input):
print('Start: ' + str(start) + ', Goal: ' + str(goal))

path_prefix_index = np.nonzero(input == PATH_PREFIX_TOKEN)[0][-1].item()
print('Path: ' + str(input[(path_prefix_index+1):]))
path = np.trim_zeros(input[np.nonzero(input == PATH_PREFIX_TOKEN)[0][0]:] - PATH_PREFIX_TOKEN) + PATH_PREFIX_TOKEN
print('Path: ' + ' '.join(['P' if v == PATH_PREFIX_TOKEN else str(v) for v in path]))


def do_evaluate_model(filepath, star_distribution=False, max_backtrack_distance=None):
@@ -322,7 +323,13 @@ def do_evaluate_model(filepath, star_distribution=False, max_backtrack_distance=
training_max_lookahead = int(suffix[:suffix.index('_')])
max_lookahead = ((max_input_size - 5) // 3 - 1) // 2

is_dfs = 'dfs' in dirname
task = 'dfs' in dirname
if '_dfs_' in dirname:
task = 'dfs'
elif '_si_' in dirname:
task = 'si'
else:
task = 'search'
if max_backtrack_distance == None:
max_backtrack_distance = (max_input_size - 4) // 4 - 1

@@ -332,6 +339,8 @@ def do_evaluate_model(filepath, star_distribution=False, max_backtrack_distance=
if not hasattr(transformer, 'pre_ln'):
transformer.pre_ln = True

PATH_PREFIX_TOKEN = (max_input_size-5) // 3 + 1

seed_generator = Random(training_seed)
seed_values = []

@@ -345,7 +354,7 @@ def get_seed(index):
NUM_TEST_SAMPLES = 1000
reserved_inputs = set()
test_accuracies = []
if is_dfs:
if task == 'dfs':
for backtrack_distance in [-1] + list(range(0, max_backtrack_distance + 1)):
generator.set_seed(get_seed(1))
inputs,outputs,labels,_ = generator.generate_dfs_training_set(max_input_size, NUM_TEST_SAMPLES, reserved_inputs, backtrack_distance, False, False, True)
@@ -361,6 +370,36 @@ def get_seed(index):
import pdb; pdb.set_trace()
print("Test accuracy = %.2f±%.2f, test loss = %f" % (test_acc, confidence_int, test_loss))
test_accuracies.append((test_acc, confidence_int, test_loss))
elif task == 'si':
max_edges = (max_input_size - 2) // 6
max_frontier_size = (max_edges + 1) // 2
max_branch_size = max_edges
frontier_branches = []
for frontier_size in range(1, max_frontier_size + 1):
for branch_size in range(1, max_branch_size + 1):
if frontier_size + branch_size > max_edges + 1:
continue
frontier_branches.append((frontier_size, branch_size))
for frontier_size, branch_size in frontier_branches:
generator.set_seed(get_seed(1))
inputs,outputs,labels,_ = generator.generate_si_training_set(max_input_size, NUM_TEST_SAMPLES, reserved_inputs, frontier_size, branch_size, False, True)
test_acc,test_loss,predictions = evaluate_model(model, inputs, outputs)
confidence_int = binomial_confidence_int(test_acc, NUM_TEST_SAMPLES)
predictions = np.array(predictions.cpu())
'''print("Mistaken inputs:")
incorrect_indices,_ = np.nonzero(np.take_along_axis(outputs, predictions[:,None], axis=1) == 0)
np.set_printoptions(threshold=10_000)
for incorrect_index in incorrect_indices:
print_graph(inputs[incorrect_index, :])
print("Expected answer: {}, predicted answer: {} (label: {})\n".format(np.nonzero(outputs[incorrect_index])[0], predictions[incorrect_index], labels[incorrect_index]))'''
selection_inputs = (inputs[:,-1] == PATH_PREFIX_TOKEN)
selection_acc = np.sum(np.take_along_axis(outputs[selection_inputs], predictions[selection_inputs,None], axis=1)) / np.sum(selection_inputs)
inference_acc = np.sum(np.take_along_axis(outputs[~selection_inputs], predictions[~selection_inputs,None], axis=1)) / np.sum(~selection_inputs)
selection_confidence_int = binomial_confidence_int(selection_acc, np.sum(selection_inputs))
inference_confidence_int = binomial_confidence_int(inference_acc, np.sum(~selection_inputs))
print("(%u,%u) Test accuracy = %.2f±%.2f, test loss = %f, selection accuracy = %.2f±%.2f, inference accuracy = %.2f±%.2f" % (frontier_size, branch_size, test_acc, confidence_int, test_loss, selection_acc, selection_confidence_int, inference_acc, inference_confidence_int))
#import pdb; pdb.set_trace()
test_accuracies.append((test_acc, confidence_int, test_loss))
elif star_distribution:
for spoke_length in range(1, max_lookahead + 1):
max_spoke_count = ((max_input_size - 5) // 3 - 1) // spoke_length
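The new selection/inference breakdown above reports each accuracy with a half-width from binomial_confidence_int, which is defined elsewhere in the repository and not shown in this diff. As an assumption, a standard Wald normal-approximation interval consistent with the "accuracy ± interval" printout would look like the following (the actual helper may use a different formula):

import math

def binomial_confidence_int(accuracy, num_samples, z=1.96):
    # Half-width of an approximate 95% confidence interval for a binomial proportion;
    # guard against num_samples == 0 when a split (selection/inference) happens to be empty.
    return z * math.sqrt(accuracy * (1.0 - accuracy) / max(num_samples, 1))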
58 changes: 34 additions & 24 deletions generator.cpp
@@ -181,7 +181,7 @@ bool has_cycles(array<node>& vertices) {
return false;
}

bool generate_graph_with_lookahead(array<node>& vertices, node*& start, node*& end, unsigned int num_vertices, unsigned int max_num_parents, unsigned int max_vertex_id, unsigned int lookahead, unsigned int num_paths)
bool generate_graph_with_lookahead(array<node>& vertices, node*& start, node*& end, unsigned int num_vertices, unsigned int max_num_parents, unsigned int max_vertex_id, unsigned int lookahead, unsigned int num_paths, unsigned int max_prefix_vertices)
{
num_vertices = std::max(std::max(2u, num_vertices), 1 + num_paths * lookahead);

Expand Down Expand Up @@ -215,7 +215,7 @@ bool generate_graph_with_lookahead(array<node>& vertices, node*& start, node*& e
}
}

unsigned int num_prefix_vertices = randrange(num_vertices - index + 1);
unsigned int num_prefix_vertices = randrange(min(max_prefix_vertices + 1, num_vertices - index + 1));
node* prev_vertex = &vertices[0];
for (unsigned int i = 0; i < num_prefix_vertices; i++) {
vertices[index].children.add(prev_vertex);
@@ -326,13 +326,13 @@ bool generate_graph_with_lookahead(array<node>& vertices, node*& start, node*& e
return true;
}

bool generate_example(array<node>& vertices, node*& start, node*& end, array<array<node*>>& paths, unsigned int num_vertices, unsigned int max_num_parents, unsigned int max_vertex_id, bool get_shortest_paths, int lookahead, unsigned int num_paths)
bool generate_example(array<node>& vertices, node*& start, node*& end, array<array<node*>>& paths, unsigned int num_vertices, unsigned int max_num_parents, unsigned int max_vertex_id, bool get_shortest_paths, int lookahead, unsigned int num_paths, unsigned int max_prefix_vertices)
{
if (lookahead == -1) {
if (!generate_graph(vertices, start, end, num_vertices, max_num_parents, max_vertex_id))
return false;
} else {
if (!generate_graph_with_lookahead(vertices, start, end, num_vertices, max_num_parents, max_vertex_id, lookahead, num_paths))
if (!generate_graph_with_lookahead(vertices, start, end, num_vertices, max_num_parents, max_vertex_id, lookahead, num_paths, max_prefix_vertices))
return false;
}

@@ -507,7 +507,7 @@ bool has_path(const node* start, const node* end)
return false;
}

py::tuple generate_training_set(const unsigned int max_input_size, const uint64_t dataset_size, const int max_lookahead, const unsigned int max_edges, const py::object& reserved_inputs, const int distance_from_start, const bool quiet=false)
py::tuple generate_training_set(const unsigned int max_input_size, const uint64_t dataset_size, const int max_lookahead, const unsigned int max_edges, const py::object& reserved_inputs, const int distance_from_start, const int max_prefix_vertices, const bool quiet=false)
{
const unsigned int QUERY_PREFIX_TOKEN = (max_input_size-5) / 3 + 4;
const unsigned int PADDING_TOKEN = (max_input_size-5) / 3 + 3;
@@ -553,7 +553,7 @@ py::tuple generate_training_set(const unsigned int max_input_size, const uint64_
while (true) {
if (max_lookahead == -1) {
unsigned int num_vertices = randrange(3, (max_input_size - 5) / 3);
if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, -1, 0)) {
if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, -1, 0, max_prefix_vertices == -1 ? max_input_size : max_prefix_vertices)) {
for (node& n : g) core::free(n);
for (array<node*>& a : paths) core::free(a);
g.length = 0; paths.length = 0;
@@ -575,7 +575,7 @@ py::tuple generate_training_set(const unsigned int max_input_size, const uint64_
}

unsigned int num_vertices = std::min(std::min(lookahead * num_paths + 1 + randrange(0, 6), (max_input_size-5) / 3), max_edges + 1);
if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, lookahead, num_paths)) {
if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, lookahead, num_paths, max_prefix_vertices == -1 ? max_input_size : max_prefix_vertices)) {
for (node& n : g) core::free(n);
for (array<node*>& a : paths) core::free(a);
g.length = 0; paths.length = 0;
@@ -726,7 +726,7 @@ py::array_t<int64_t, py::array::c_style> lookahead_histogram(const unsigned int
array<array<node*>> paths(8);
while (true) {
unsigned int num_vertices = randrange(3, (max_input_size - 5) / 3);
if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, -1, 0)) {
if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, -1, 0, -1)) {
for (node& n : g) core::free(n);
for (array<node*>& a : paths) core::free(a);
g.length = 0; paths.length = 0;
@@ -848,7 +848,7 @@ py::tuple generate_reachable_training_set(const unsigned int max_input_size, con
}

unsigned int num_vertices = std::min(std::min(lookahead * num_paths + 1 + randrange(0, 6), (max_input_size-5) / 3), max_edges + 1);
if (!generate_example(g, start, end, paths, num_vertices, 4, max_vertex_id, true, lookahead, num_paths)) {
if (!generate_example(g, start, end, paths, num_vertices, 4, max_vertex_id, true, lookahead, num_paths, -1)) {
for (node& n : g) core::free(n);
for (array<node*>& a : paths) core::free(a);
g.length = 0; paths.length = 0;
@@ -1741,32 +1741,38 @@ bool generate_si_example(array<node>& vertices, const node*& start, const node*&
pair<node*, const node*> entry = reachability_stack.pop();
node* next = entry.key;
const node* parent = entry.value;
if (reverse_ptrs.contains(next))
if (next->id > end_index || reverse_ptrs.contains(next))
continue;
reverse_ptrs.put(next, parent);
for (node* child : next->children)
if (!reachability_stack.contains(make_pair<node*, const node*>(child, next)))
reachability_stack.add(make_pair<node*, const node*>(child, next));
}
const node* current;
const node* parent;
if (reverse_ptrs.contains(&vertices[end_index])) {
/* make sure none of the edges on the path from `start` to `end` are removable */
const node* current = &vertices[end_index];
const node* parent = reverse_ptrs.get(current);
while (true) {
pair<unsigned int, unsigned int> entry = make_pair(parent->id, current->id);
unsigned int index = removable_edges.index_of(entry);
if (index != removable_edges.length)
removable_edges.remove(index);
if (path.contains(parent) || parent == start)
break;
current = parent;
parent = reverse_ptrs.get(current);
}
current = &vertices[end_index];
parent = reverse_ptrs.get(current);
} else {
reverse_ptrs.put(&vertices[start_index], nullptr);
node* new_parent = choice(reverse_ptrs.keys, reverse_ptrs.size);
new_parent->children.add(&vertices[end_index]);
vertices[end_index].parents.add(new_parent);
total_edge_count++; new_edge_count++;

current = &vertices[end_index];
parent = new_parent;
}
/* make sure none of the edges on the path from `start` to `end` are removable */
while (true) {
pair<unsigned int, unsigned int> entry = make_pair(parent->id, current->id);
unsigned int index = removable_edges.index_of(entry);
if (index != removable_edges.length)
removable_edges.remove(index);
if (parent == start)
break;
current = parent;
parent = reverse_ptrs.get(current);
}

/* remove edges to avoid generating a graph with too many edges */
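For readers skimming the C++ above: the restructuring factors the path-protecting loop out of the reachable case, so it also runs after a bridging edge to the end vertex is added. A rough Python rendering of that loop, with reverse_ptrs as a dict from vertex id to the id it was first reached from and removable_edges as a set of (parent_id, child_id) pairs (illustrative sketch only, not the repository's data structures):

def protect_path_edges(reverse_ptrs, removable_edges, start, end):
    # Walk reverse pointers from end back to start and make every edge on that
    # path non-removable, mirroring the while(true) loop in generate_si_example.
    current = end
    parent = reverse_ptrs[current]
    while True:
        removable_edges.discard((parent, current))
        if parent == start:
            break
        current = parent
        parent = reverse_ptrs[current]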
@@ -2007,7 +2013,7 @@ py::tuple generate_si_training_set(const unsigned int max_input_size, const uint
const node* current_node = &g[current_node_index];
unsigned int branch_size = current_node->children.length;

bool is_selection_step = (randrange(2) == 0);
bool is_selection_step;
if (path.length == 0)
is_selection_step = false;
else
is_selection_step = (randrange(2) == 0);
if (3*(path.length/2) + (is_selection_step ? 1 : 2) > 3*(max_edges - 1) + 1) {
/* we have just barely too many edges */
for (node& n : g) core::free(n);