From 8a181f7f7f6a0540031312c040344d504c4af5f1 Mon Sep 17 00:00:00 2001 From: Abulhair Saparov Date: Mon, 18 Nov 2024 21:32:15 -0500 Subject: [PATCH] Added options to train on examples from the star graph distribution, as well as an ablated form of the balanced distribution. --- analyze.py | 45 ++++++++++++++++++-- generator.cpp | 58 ++++++++++++++----------- train.py | 114 ++++++++++++++++++++++++++++++++++++-------------- 3 files changed, 159 insertions(+), 58 deletions(-) diff --git a/analyze.py b/analyze.py index 8882330..21fdfde 100644 --- a/analyze.py +++ b/analyze.py @@ -304,7 +304,8 @@ def print_graph(input): print('Start: ' + str(start) + ', Goal: ' + str(goal)) path_prefix_index = np.nonzero(input == PATH_PREFIX_TOKEN)[0][-1].item() - print('Path: ' + str(input[(path_prefix_index+1):])) + path = np.trim_zeros(input[np.nonzero(input == PATH_PREFIX_TOKEN)[0][0]:] - PATH_PREFIX_TOKEN) + PATH_PREFIX_TOKEN + print('Path: ' + ' '.join(['P' if v == PATH_PREFIX_TOKEN else str(v) for v in path])) def do_evaluate_model(filepath, star_distribution=False, max_backtrack_distance=None): @@ -322,7 +323,13 @@ def do_evaluate_model(filepath, star_distribution=False, max_backtrack_distance= training_max_lookahead = int(suffix[:suffix.index('_')]) max_lookahead = ((max_input_size - 5) // 3 - 1) // 2 - is_dfs = 'dfs' in dirname + task = 'dfs' in dirname + if '_dfs_' in dirname: + task = 'dfs' + elif '_si_' in dirname: + task = 'si' + else: + task = 'search' if max_backtrack_distance == None: max_backtrack_distance = (max_input_size - 4) // 4 - 1 @@ -332,6 +339,8 @@ def do_evaluate_model(filepath, star_distribution=False, max_backtrack_distance= if not hasattr(transformer, 'pre_ln'): transformer.pre_ln = True + PATH_PREFIX_TOKEN = (max_input_size-5) // 3 + 1 + seed_generator = Random(training_seed) seed_values = [] @@ -345,7 +354,7 @@ def get_seed(index): NUM_TEST_SAMPLES = 1000 reserved_inputs = set() test_accuracies = [] - if is_dfs: + if task == 'dfs': for backtrack_distance in [-1] + list(range(0, max_backtrack_distance + 1)): generator.set_seed(get_seed(1)) inputs,outputs,labels,_ = generator.generate_dfs_training_set(max_input_size, NUM_TEST_SAMPLES, reserved_inputs, backtrack_distance, False, False, True) @@ -361,6 +370,36 @@ def get_seed(index): import pdb; pdb.set_trace() print("Test accuracy = %.2f±%.2f, test loss = %f" % (test_acc, confidence_int, test_loss)) test_accuracies.append((test_acc, confidence_int, test_loss)) + elif task == 'si': + max_edges = (max_input_size - 2) // 6 + max_frontier_size = (max_edges + 1) // 2 + max_branch_size = max_edges + frontier_branches = [] + for frontier_size in range(1, max_frontier_size + 1): + for branch_size in range(1, max_branch_size + 1): + if frontier_size + branch_size > max_edges + 1: + continue + frontier_branches.append((frontier_size, branch_size)) + for frontier_size, branch_size in frontier_branches: + generator.set_seed(get_seed(1)) + inputs,outputs,labels,_ = generator.generate_si_training_set(max_input_size, NUM_TEST_SAMPLES, reserved_inputs, frontier_size, branch_size, False, True) + test_acc,test_loss,predictions = evaluate_model(model, inputs, outputs) + confidence_int = binomial_confidence_int(test_acc, NUM_TEST_SAMPLES) + predictions = np.array(predictions.cpu()) + '''print("Mistaken inputs:") + incorrect_indices,_ = np.nonzero(np.take_along_axis(outputs, predictions[:,None], axis=1) == 0) + np.set_printoptions(threshold=10_000) + for incorrect_index in incorrect_indices: + print_graph(inputs[incorrect_index, :]) + print("Expected answer: {}, predicted answer: {} (label: {})\n".format(np.nonzero(outputs[incorrect_index])[0], predictions[incorrect_index], labels[incorrect_index]))''' + selection_inputs = (inputs[:,-1] == PATH_PREFIX_TOKEN) + selection_acc = np.sum(np.take_along_axis(outputs[selection_inputs], predictions[selection_inputs,None], axis=1)) / np.sum(selection_inputs) + inference_acc = np.sum(np.take_along_axis(outputs[~selection_inputs], predictions[~selection_inputs,None], axis=1)) / np.sum(~selection_inputs) + selection_confidence_int = binomial_confidence_int(selection_acc, np.sum(selection_inputs)) + inference_confidence_int = binomial_confidence_int(inference_acc, np.sum(~selection_inputs)) + print("(%u,%u) Test accuracy = %.2f±%.2f, test loss = %f, selection accuracy = %.2f±%.2f, inference accuracy = %.2f±%.2f" % (frontier_size, branch_size, test_acc, confidence_int, test_loss, selection_acc, selection_confidence_int, inference_acc, inference_confidence_int)) + #import pdb; pdb.set_trace() + test_accuracies.append((test_acc, confidence_int, test_loss)) elif star_distribution: for spoke_length in range(1, max_lookahead + 1): max_spoke_count = ((max_input_size - 5) // 3 - 1) // spoke_length diff --git a/generator.cpp b/generator.cpp index 98f5213..32959b0 100644 --- a/generator.cpp +++ b/generator.cpp @@ -181,7 +181,7 @@ bool has_cycles(array& vertices) { return false; } -bool generate_graph_with_lookahead(array& vertices, node*& start, node*& end, unsigned int num_vertices, unsigned int max_num_parents, unsigned int max_vertex_id, unsigned int lookahead, unsigned int num_paths) +bool generate_graph_with_lookahead(array& vertices, node*& start, node*& end, unsigned int num_vertices, unsigned int max_num_parents, unsigned int max_vertex_id, unsigned int lookahead, unsigned int num_paths, unsigned int max_prefix_vertices) { num_vertices = std::max(std::max(2u, num_vertices), 1 + num_paths * lookahead); @@ -215,7 +215,7 @@ bool generate_graph_with_lookahead(array& vertices, node*& start, node*& e } } - unsigned int num_prefix_vertices = randrange(num_vertices - index + 1); + unsigned int num_prefix_vertices = randrange(min(max_prefix_vertices + 1, num_vertices - index + 1)); node* prev_vertex = &vertices[0]; for (unsigned int i = 0; i < num_prefix_vertices; i++) { vertices[index].children.add(prev_vertex); @@ -326,13 +326,13 @@ bool generate_graph_with_lookahead(array& vertices, node*& start, node*& e return true; } -bool generate_example(array& vertices, node*& start, node*& end, array>& paths, unsigned int num_vertices, unsigned int max_num_parents, unsigned int max_vertex_id, bool get_shortest_paths, int lookahead, unsigned int num_paths) +bool generate_example(array& vertices, node*& start, node*& end, array>& paths, unsigned int num_vertices, unsigned int max_num_parents, unsigned int max_vertex_id, bool get_shortest_paths, int lookahead, unsigned int num_paths, unsigned int max_prefix_vertices) { if (lookahead == -1) { if (!generate_graph(vertices, start, end, num_vertices, max_num_parents, max_vertex_id)) return false; } else { - if (!generate_graph_with_lookahead(vertices, start, end, num_vertices, max_num_parents, max_vertex_id, lookahead, num_paths)) + if (!generate_graph_with_lookahead(vertices, start, end, num_vertices, max_num_parents, max_vertex_id, lookahead, num_paths, max_prefix_vertices)) return false; } @@ -507,7 +507,7 @@ bool has_path(const node* start, const node* end) return false; } -py::tuple generate_training_set(const unsigned int max_input_size, const uint64_t dataset_size, const int max_lookahead, const unsigned int max_edges, const py::object& reserved_inputs, const int distance_from_start, const bool quiet=false) +py::tuple generate_training_set(const unsigned int max_input_size, const uint64_t dataset_size, const int max_lookahead, const unsigned int max_edges, const py::object& reserved_inputs, const int distance_from_start, const int max_prefix_vertices, const bool quiet=false) { const unsigned int QUERY_PREFIX_TOKEN = (max_input_size-5) / 3 + 4; const unsigned int PADDING_TOKEN = (max_input_size-5) / 3 + 3; @@ -553,7 +553,7 @@ py::tuple generate_training_set(const unsigned int max_input_size, const uint64_ while (true) { if (max_lookahead == -1) { unsigned int num_vertices = randrange(3, (max_input_size - 5) / 3); - if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, -1, 0)) { + if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, -1, 0, max_prefix_vertices == -1 ? max_input_size : max_prefix_vertices)) { for (node& n : g) core::free(n); for (array& a : paths) core::free(a); g.length = 0; paths.length = 0; @@ -575,7 +575,7 @@ py::tuple generate_training_set(const unsigned int max_input_size, const uint64_ } unsigned int num_vertices = std::min(std::min(lookahead * num_paths + 1 + randrange(0, 6), (max_input_size-5) / 3), max_edges + 1); - if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, lookahead, num_paths)) { + if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, lookahead, num_paths, max_prefix_vertices == -1 ? max_input_size : max_prefix_vertices)) { for (node& n : g) core::free(n); for (array& a : paths) core::free(a); g.length = 0; paths.length = 0; @@ -726,7 +726,7 @@ py::array_t lookahead_histogram(const unsigned int array> paths(8); while (true) { unsigned int num_vertices = randrange(3, (max_input_size - 5) / 3); - if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, -1, 0)) { + if (!generate_example(g, start, end, paths, num_vertices, 4, (max_input_size - 5) / 3, true, -1, 0, -1)) { for (node& n : g) core::free(n); for (array& a : paths) core::free(a); g.length = 0; paths.length = 0; @@ -848,7 +848,7 @@ py::tuple generate_reachable_training_set(const unsigned int max_input_size, con } unsigned int num_vertices = std::min(std::min(lookahead * num_paths + 1 + randrange(0, 6), (max_input_size-5) / 3), max_edges + 1); - if (!generate_example(g, start, end, paths, num_vertices, 4, max_vertex_id, true, lookahead, num_paths)) { + if (!generate_example(g, start, end, paths, num_vertices, 4, max_vertex_id, true, lookahead, num_paths, -1)) { for (node& n : g) core::free(n); for (array& a : paths) core::free(a); g.length = 0; paths.length = 0; @@ -1741,32 +1741,38 @@ bool generate_si_example(array& vertices, const node*& start, const node*& pair entry = reachability_stack.pop(); node* next = entry.key; const node* parent = entry.value; - if (reverse_ptrs.contains(next)) + if (next->id > end_index || reverse_ptrs.contains(next)) continue; reverse_ptrs.put(next, parent); for (node* child : next->children) if (!reachability_stack.contains(make_pair(child, next))) reachability_stack.add(make_pair(child, next)); } + const node* current; + const node* parent; if (reverse_ptrs.contains(&vertices[end_index])) { - /* make sure none of the edges on the path from `start` to `end` are removable */ - const node* current = &vertices[end_index]; - const node* parent = reverse_ptrs.get(current); - while (true) { - pair entry = make_pair(parent->id, current->id); - unsigned int index = removable_edges.index_of(entry); - if (index != removable_edges.length) - removable_edges.remove(index); - if (path.contains(parent) || parent == start) - break; - current = parent; - parent = reverse_ptrs.get(current); - } + current = &vertices[end_index]; + parent = reverse_ptrs.get(current); } else { + reverse_ptrs.put(&vertices[start_index], nullptr); node* new_parent = choice(reverse_ptrs.keys, reverse_ptrs.size); new_parent->children.add(&vertices[end_index]); vertices[end_index].parents.add(new_parent); total_edge_count++; new_edge_count++; + + current = &vertices[end_index]; + parent = new_parent; + } + /* make sure none of the edges on the path from `start` to `end` are removable */ + while (true) { + pair entry = make_pair(parent->id, current->id); + unsigned int index = removable_edges.index_of(entry); + if (index != removable_edges.length) + removable_edges.remove(index); + if (parent == start) + break; + current = parent; + parent = reverse_ptrs.get(current); } /* remove edges to avoid generating a graph with too many edges */ @@ -2007,7 +2013,11 @@ py::tuple generate_si_training_set(const unsigned int max_input_size, const uint const node* current_node = &g[current_node_index]; unsigned int branch_size = current_node->children.length; - bool is_selection_step = (randrange(2) == 0); + bool is_selection_step; + if (path.length == 0) + is_selection_step = false; + else + is_selection_step = (randrange(2) == 0); if (3*(path.length/2) + (is_selection_step ? 1 : 2) > 3*(max_edges - 1) + 1) { /* we have just barely too many edges */ for (node& n : g) core::free(n); diff --git a/train.py b/train.py index 9b43547..3391c89 100644 --- a/train.py +++ b/train.py @@ -36,7 +36,7 @@ def build_module(name): sys.stdout = old_stdout python_extension_suffix = sysconfig.get_config_var("EXT_SUFFIX") - command = f"g++ -Ofast -fno-stack-protector -Wall -Wpedantic -shared -fPIC {includes} -I. {name}.cpp -o {name}{python_extension_suffix}" + command = f"g++ -Ofast -DNDEBUG -fno-stack-protector -Wall -Wpedantic -shared -fPIC {includes} -I. {name}.cpp -o {name}{python_extension_suffix}" print(command) if os.system(command) != 0: print(f"ERROR: Unable to compile `{name}.cpp`.") @@ -153,7 +153,7 @@ def get_descendants(node): queue.append(child) return descendants -def generate_graph_with_lookahead(num_vertices, max_num_parents, max_vertex_id, lookahead, num_paths): +def generate_graph_with_lookahead(num_vertices, max_num_parents, max_vertex_id, lookahead, num_paths, max_prefix_vertices): num_vertices = max(2, num_vertices, 1 + num_paths * lookahead) vertices = [] @@ -179,7 +179,7 @@ def generate_graph_with_lookahead(num_vertices, max_num_parents, max_vertex_id, vertices[index - 1].children.append(vertices[index]) index += 1 - num_prefix_vertices = randrange(num_vertices - index + 1) + num_prefix_vertices = randrange(min(max_prefix_vertices + 1, num_vertices - index + 1)) prev_vertex = vertices[0] for i in range(num_prefix_vertices): vertices[index].children.append(prev_vertex) @@ -300,7 +300,7 @@ def compute_paths(graph, start, end, get_shortest_paths): return paths -def generate_example(num_vertices, max_num_parents, max_vertex_id, get_shortest_paths=True, lookahead=None, num_paths=None): +def generate_example(num_vertices, max_num_parents, max_vertex_id, get_shortest_paths=True, lookahead=None, num_paths=None, max_prefix_vertices=None): if lookahead == None: graph = generate_graph(num_vertices, max_num_parents, max_vertex_id) @@ -311,7 +311,7 @@ def generate_example(num_vertices, max_num_parents, max_vertex_id, get_shortest_ if end != start: break else: - graph, start, end = generate_graph_with_lookahead(num_vertices, max_num_parents, max_vertex_id, lookahead, num_paths) + graph, start, end = generate_graph_with_lookahead(num_vertices, max_num_parents, max_vertex_id, lookahead, num_paths, max_prefix_vertices) if graph == None: return None, None, None, None @@ -364,7 +364,7 @@ def generate_star_graph(num_spokes, spoke_length, max_vertex_id): def binomial_confidence_int(p, n): return 1.96 * np.sqrt(p * (1.0 - p) / n) -def generate_star_graph_data(max_input_size, num_spokes, spoke_length, num_samples=1000): +def generate_star_graph_data(max_input_size, num_spokes, spoke_length, num_samples=1000, reserved_inputs=None, uniform=False): QUERY_PREFIX_TOKEN = (max_input_size-5) // 3 + 4 PADDING_TOKEN = (max_input_size-5) // 3 + 3 EDGE_PREFIX_TOKEN = (max_input_size-5) // 3 + 2 @@ -373,8 +373,20 @@ def generate_star_graph_data(max_input_size, num_spokes, spoke_length, num_sampl total_predictions = 0 inputs = np.empty((num_samples, max_input_size), dtype=np.int64) outputs = np.empty(num_samples, dtype=np.int64) + num_collisions = 0 + spoke_lengths = [] + if uniform: + for spoke_len in range(1, spoke_length + 1): + max_spoke_count = ((max_input_size - 5) // 3 - 1) // spoke_len + for spoke_count in range(1, max_spoke_count + 1): + if spoke_count > num_spokes: + continue + spoke_lengths.append((spoke_len, spoke_count)) + else: + spoke_lengths = [(spoke_length, num_spokes)] while total_predictions < num_samples: - g, start, end = generate_star_graph(num_spokes, spoke_length, (max_input_size - 5) // 3) + spoke_len,spoke_count = choice(spoke_lengths) + g, start, end = generate_star_graph(spoke_count, spoke_len, (max_input_size - 5) // 3) paths = compute_paths(g, start, end, get_shortest_paths=True) if paths == None: @@ -387,6 +399,9 @@ def generate_star_graph_data(max_input_size, num_spokes, spoke_length, num_sampl prefix.extend([QUERY_PREFIX_TOKEN, start.id, end.id, PATH_PREFIX_TOKEN]) prefix.append(start.id) + if reserved_inputs != None and tuple(prefix) in reserved_inputs: + num_collisions += 1 + continue input = [PADDING_TOKEN] * (max_input_size - len(prefix)) + prefix inputs[total_predictions,:] = input outputs[total_predictions] = paths[0][1].id @@ -394,9 +409,9 @@ def generate_star_graph_data(max_input_size, num_spokes, spoke_length, num_sampl if total_predictions == num_samples: break - return inputs, outputs + return inputs, outputs, num_collisions -def generate_eval_data(max_input_size, min_path_length=2, distance_from_start=-1, distance_from_end=-1, lookahead_steps=None, num_paths_at_fork=None, num_samples=1000): +def generate_eval_data(max_input_size, min_path_length=2, distance_from_start=-1, distance_from_end=-1, lookahead_steps=None, num_paths_at_fork=None, num_samples=1000, max_prefix_vertices=None): QUERY_PREFIX_TOKEN = (max_input_size-5) // 3 + 4 PADDING_TOKEN = (max_input_size-5) // 3 + 3 EDGE_PREFIX_TOKEN = (max_input_size-5) // 3 + 2 @@ -431,7 +446,7 @@ def generate_eval_data(max_input_size, min_path_length=2, distance_from_start=-1 num_vertices = min(lookahead_steps * num_paths + 1 + randrange(0, 6), (max_input_size - 5) // 3) else: num_paths = None - g, start, end, paths = generate_example(num_vertices, 4, (max_input_size - 5) // 3, get_shortest_paths=False, lookahead=lookahead_steps, num_paths=num_paths) + g, start, end, paths = generate_example(num_vertices, 4, (max_input_size - 5) // 3, get_shortest_paths=False, lookahead=lookahead_steps, num_paths=num_paths, max_prefix_vertices=(max_input_size if max_prefix_vertices == None else max_prefix_vertices)) if paths != None and min([len(path) for path in paths]) > (min(lookahead_steps, min_path_length) if lookahead_steps != None else min_path_length): break @@ -661,14 +676,18 @@ def train(max_input_size, dataset_size, distribution, max_lookahead, seed_value, print('ERROR: Curriculum learning is only supported with streaming training (i.e. dataset_size = -1).') stdout.flush() return - if distribution == "crafted" and max_lookahead == None: - print('ERROR: Crafted training distribution is selected but `max_lookhead` argument is missing.') + if distribution in ("crafted", "crafted_no_prefix", "star") and max_lookahead == None: + print('ERROR: Crafted or star training distribution is selected but `max_lookahead` argument is missing.') stdout.flush() return if distribution == "simple" and max_lookahead != None: print('ERROR: `max_lookahead` is not supported with the simple training distribution.') stdout.flush() return + if distribution in ("crafted_no_prefix", "star") and task != "search": + print('ERROR: Distributions `crafted_no_prefix` and `star` are only supported with task `search`.') + stdout.flush() + return if max_lookahead == None: max_lookahead = -1 @@ -691,6 +710,7 @@ def train(max_input_size, dataset_size, distribution, max_lookahead, seed_value, if backtrack_distance == 8: eval_inputs, eval_outputs = inputs, outputs elif task == 'si': + NUM_TEST_SAMPLES = 1000 max_edges = (max_input_size - 2) // 6 max_frontier_size = (max_edges + 1) // 2 max_branch_size = max_edges @@ -708,29 +728,50 @@ def train(max_input_size, dataset_size, distribution, max_lookahead, seed_value, print('Reserving OOD test data for frontier_size = {}, branch_size = {}'.format(frontier_size, branch_size)) stdout.flush() - inputs,outputs,_,_ = generator.generate_si_training_set(max_input_size, 10000 if (frontier_size == 4 and branch_size == 4) else NUM_TEST_SAMPLES, reserved_inputs, frontier_size, branch_size, False, True) + inputs,outputs,_,_ = generator.generate_si_training_set(max_input_size, NUM_TEST_SAMPLES, reserved_inputs, frontier_size, branch_size, False, True) print('Done. Throughput: {} examples/s'.format(NUM_TEST_SAMPLES / (time.perf_counter() - gen_eval_start_time))) for i in range(inputs.shape[0]): reserved_inputs.add(tuple([x for x in inputs[i,:] if x != PADDING_TOKEN])) if frontier_size == 4 and branch_size == 4: eval_inputs, eval_outputs = inputs, outputs elif task == 'search': - max_test_lookahead = ((max_input_size - 5) // 3 - 1) // 2 - dist_from_start = 1 if add_padding else -1 - for lookahead in list(range(1, max_test_lookahead + 1)) + [None]: - gen_eval_start_time = time.perf_counter() - setstate(random_state) - np.random.set_state(np_random_state) - torch.set_rng_state(torch_random_state) - - print('Reserving OOD test data for lookahead = {}'.format(lookahead)) - stdout.flush() - inputs,outputs = generate_eval_data(max_input_size, min_path_length=2, distance_from_start=dist_from_start, distance_from_end=-1, lookahead_steps=lookahead, num_paths_at_fork=None, num_samples=NUM_TEST_SAMPLES) - print('Done. Throughput: {} examples/s'.format(NUM_TEST_SAMPLES / (time.perf_counter() - gen_eval_start_time))) - for i in range(inputs.shape[0]): - reserved_inputs.add(tuple([x for x in inputs[i,:] if x != PADDING_TOKEN])) - if lookahead == None: - eval_inputs, eval_outputs = inputs, outputs + if distribution in ('crafted', 'crafted_no_prefix'): + max_test_lookahead = ((max_input_size - 5) // 3 - 1) // 2 + dist_from_start = 1 if add_padding else -1 + for lookahead in list(range(1, max_test_lookahead + 1)) + [None]: + gen_eval_start_time = time.perf_counter() + setstate(random_state) + np.random.set_state(np_random_state) + torch.set_rng_state(torch_random_state) + + print('Reserving OOD test data for lookahead = {}'.format(lookahead)) + stdout.flush() + if distribution == 'crafted': + inputs,outputs = generate_eval_data(max_input_size, min_path_length=2, distance_from_start=dist_from_start, distance_from_end=-1, lookahead_steps=lookahead, num_paths_at_fork=None, num_samples=NUM_TEST_SAMPLES, max_prefix_vertices=None) + elif distribution == 'crafted_no_prefix': + inputs,outputs = generate_eval_data(max_input_size, min_path_length=2, distance_from_start=dist_from_start, distance_from_end=-1, lookahead_steps=lookahead, num_paths_at_fork=None, num_samples=NUM_TEST_SAMPLES, max_prefix_vertices=0) + print('Done. Throughput: {} examples/s'.format(NUM_TEST_SAMPLES / (time.perf_counter() - gen_eval_start_time))) + for i in range(inputs.shape[0]): + reserved_inputs.add(tuple([x for x in inputs[i,:] if x != PADDING_TOKEN])) + if lookahead == None: + eval_inputs, eval_outputs = inputs, outputs + elif distribution == 'star': + for spoke_length in range(1, max_lookahead + 1): + max_spoke_count = ((max_input_size - 5) // 3 - 1) // spoke_length + for num_spokes in range(1, max_spoke_count + 1): + gen_eval_start_time = time.perf_counter() + setstate(random_state) + np.random.set_state(np_random_state) + torch.set_rng_state(torch_random_state) + + print('Reserving OOD test data for spoke_length = {} and num_spokes = {}'.format(spoke_length, num_spokes)) + stdout.flush() + inputs,outputs,_ = generate_star_graph_data(max_input_size, num_spokes, spoke_length, num_samples=NUM_TEST_SAMPLES) + print('Done. Throughput: {} examples/s'.format(NUM_TEST_SAMPLES / (time.perf_counter() - gen_eval_start_time))) + for i in range(inputs.shape[0]): + reserved_inputs.add(tuple([x for x in inputs[i,:] if x != PADDING_TOKEN])) + if spoke_length == 4 and num_spokes == 3: + eval_inputs, eval_outputs = inputs, outputs else: print('ERROR: Unrecognized task "{}".'.format(task)) stdout.flush() @@ -811,6 +852,8 @@ def train(max_input_size, dataset_size, distribution, max_lookahead, seed_value, filename += '_looped' if task != 'search': filename += '_' + task + if distribution != 'crafted': + filename += '_' + distribution.replace('_', '-') if nhead != 1: filename += '_nhead' + str(nhead) if warm_up != 0: @@ -940,6 +983,7 @@ def process_data(self, start): current = start worker_info = torch.utils.data.get_worker_info() worker_id = worker_info.id + max_prefix_vertices = (0 if distribution == 'crafted_no_prefix' else max_input_size) while True: worker_start_time = time.perf_counter() new_seed = get_seed(current) @@ -954,7 +998,15 @@ def process_data(self, start): elif task == 'si': inputs, outputs, labels, num_collisions = generator.generate_si_training_set(max_input_size, BATCH_SIZE, reserved_inputs, max_frontier_size, max_branch_size, True, True) else: - inputs, outputs, labels, num_collisions = generator.generate_training_set(max_input_size, BATCH_SIZE, self.lookahead, self.max_edges, reserved_inputs, dist_from_start, True) + if distribution == 'star': + max_spoke_length = self.lookahead + max_spoke_count = ((max_input_size - 5) // 3 - 1) // 1 + inputs, labels, num_collisions = generate_star_graph_data(max_input_size, num_spokes=max_spoke_count, spoke_length=max_spoke_length, num_samples=BATCH_SIZE, reserved_inputs=reserved_inputs, uniform=True) + ntokens = (max_input_size - 5) // 3 + 5 + outputs = np.zeros((BATCH_SIZE, ntokens)) + outputs[:,labels] = 1.0 + else: + inputs, outputs, labels, num_collisions = generator.generate_training_set(max_input_size, BATCH_SIZE, self.lookahead, self.max_edges, reserved_inputs, dist_from_start, max_prefix_vertices, True) if num_collisions != 0: with self.collisions_lock: self.total_collisions.value += num_collisions @@ -1196,7 +1248,7 @@ def parse_bool_arg(v): parser.add_argument("--curriculum", type=str, required=True, choices=["y", "n", "layerbylayer", "layerbylayer2"]) parser.add_argument("--looped", type=parse_bool_arg, default=False) parser.add_argument("--task", type=str, default="search", choices=["search", "dfs", "si"]) - parser.add_argument("--distribution", type=str, default="crafted", choices=["simple", "crafted"]) + parser.add_argument("--distribution", type=str, default="crafted", choices=["simple", "crafted", "crafted_no_prefix", "star"]) parser.add_argument("--warm-up", type=int, default=0, required=False) parser.add_argument("--batch-size", type=int, default=2**8, required=False) parser.add_argument("--learning-rate", type=float, default=1.0e-5, required=False)