From 7a063f649cb54b46e66a77236402c124dd35d8c6 Mon Sep 17 00:00:00 2001
From: asaparov
Date: Thu, 14 Mar 2024 11:29:20 -0400
Subject: [PATCH] Fixing bugs in the probe training code.

---
 train_probe.py | 29 ++++++++++++++---------------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/train_probe.py b/train_probe.py
index d825a1e..4dcf552 100644
--- a/train_probe.py
+++ b/train_probe.py
@@ -1,10 +1,11 @@
-from random import seed, randrange
+from random import seed, randrange, Random
 import numpy as np
 import torch
 from torch import nn, LongTensor, FloatTensor
 from torch.nn import BCEWithLogitsLoss, Sigmoid
 from torch.utils.data import DataLoader
 from Sophia import SophiaG
+import multiprocessing
 
 def build_module(name):
     from os import system
@@ -136,9 +137,10 @@ def evaluate_decoder(model, max_input_size):
     print("exact match accuracy = %.2f, partial match accuracy = %.2f" % (exact_match / num_examples, partial_match / num_examples))
 
 if __name__ == "__main__":
-    seed(1)
-    torch.manual_seed(1)
-    np.random.seed(1)
+    seed_value = 1
+    seed(seed_value)
+    torch.manual_seed(seed_value)
+    np.random.seed(seed_value)
 
     from sys import argv, exit
     if len(argv) != 2:
@@ -153,7 +155,7 @@ def evaluate_decoder(model, max_input_size):
     else:
         device = torch.device('cuda')
 
-    tfm_model = torch.load(argv[1], map_location=device)
+    tfm_model, _, _, _ = torch.load(argv[1], map_location=device)
     model = TransformerProber(tfm_model, probe_layer=1)
     model.to(device)
 
@@ -198,23 +200,19 @@ def process_data(self, start):
             worker_info = torch.utils.data.get_worker_info()
             worker_id = worker_info.id
 
             while True:
-                worker_start_time = time.perf_counter()
                 new_seed = get_seed(current)
                 generator.set_seed(new_seed)
                 seed(new_seed)
                 torch.manual_seed(new_seed)
                 np.random.seed(new_seed)
-                generate_start_time = time.perf_counter()
-                inputs, outputs, num_collisions = generate_reachable_training_set(max_input_size, BATCH_SIZE, lookahead_steps, reserved_inputs, dist_from_start, reachable_distance, start_vertex_index)
-                generator.generate_training_set(max_input_size, BATCH_SIZE, max_lookahead, reserved_inputs, dist_from_start, True)
+                inputs, outputs, num_collisions = generator.generate_reachable_training_set(max_input_size, BATCH_SIZE, lookahead_steps, reserved_inputs, 1, reachable_distance, start_vertex_index)
 
                 if num_collisions != 0:
                     with self.collisions_lock:
                         self.total_collisions.value += num_collisions
                         print('Total number of training examples generated that are in the test set: {}'.format(self.total_collisions.value))
                         stdout.flush()
 
-                worker_end_time = time.perf_counter()
                 yield inputs, outputs
                 current += NUM_DATA_WORKERS
@@ -223,10 +221,12 @@ def __iter__(self):
             worker_id = worker_info.id
             return self.process_data(self.offset + worker_id)
 
+    epoch = 0
+    BATCH_SIZE = 2 ** 11
     iterable_dataset = StreamingDataset(epoch * STREAMING_BLOCK_SIZE // BATCH_SIZE)
     train_loader = DataLoader(iterable_dataset, batch_size=None, num_workers=NUM_DATA_WORKERS, pin_memory=True, prefetch_factor=8)
 
-    loss_func = BCELoss(reduction='mean')
+    loss_func = BCEWithLogitsLoss(reduction='mean')
     optimizer = SophiaG((p for p in model.parameters() if p.requires_grad), lr=1.0e-3)
 
     log_interval = 1
@@ -235,7 +235,6 @@ def __iter__(self):
 
     while True:
         for batch in train_loader:
-            batch_start_time = time.perf_counter()
             model.train()
             optimizer.zero_grad()
 
@@ -243,10 +242,10 @@ def __iter__(self):
             input = input.to(device, non_blocking=True)
             output = output.to(device, non_blocking=True)
 
-            train_start_time = time.perf_counter()
-            transfer_time += train_start_time - batch_start_time
-
             logits = model(input)
+            # only take the predictions on source vertices
+            logits = logits[:,range(2,max_input_size-5,3),:]
+            output = output[:,range(2,max_input_size-5,3),:]
             loss_val = loss_func(logits, output)
 
             epoch_loss += loss_val.item()
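
Notes:

The old loss, BCELoss(reduction='mean'), appears to be a genuine bug: the
imports visible in the first hunk bring in only BCEWithLogitsLoss and
Sigmoid from torch.nn, so BCELoss was presumably undefined here. The switch
is also the better choice on its own merits, since BCEWithLogitsLoss takes
raw logits and fuses the sigmoid with the binary cross-entropy in a
numerically stable way. A minimal standalone sketch of the equivalence (not
code from this repo):

    import torch
    from torch.nn import BCELoss, BCEWithLogitsLoss, Sigmoid

    logits = torch.randn(4, 3)                     # raw model outputs
    targets = torch.randint(0, 2, (4, 3)).float()  # binary labels

    fused = BCEWithLogitsLoss(reduction='mean')(logits, targets)
    manual = BCELoss(reduction='mean')(Sigmoid()(logits), targets)
    assert torch.allclose(fused, manual, atol=1e-6)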
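The new slicing, logits[:,range(2,max_input_size-5,3),:], keeps every third
sequence position starting at index 2, matching the "source vertices"
comment. Assuming each graph edge is encoded as a three-token group with
the source vertex at a fixed offset (an assumption; the encoding lives in
the generator module, not in this patch), the loss is now computed only at
those positions. A small sketch of which indices survive, with a
hypothetical max_input_size:

    max_input_size = 24  # hypothetical value for illustration
    positions = list(range(2, max_input_size - 5, 3))
    print(positions)     # [2, 5, 8, 11, 14, 17]
    # logits[:, positions, :] and output[:, positions, :] keep only these
    # sequence positions, so the probe is trained only on source vertices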
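On the data side, each DataLoader worker in process_data starts its counter
at offset + worker_id and strides by NUM_DATA_WORKERS, reseeding Python,
NumPy, Torch, and the generator module from get_seed(current) before every
batch, so workers draw from disjoint, reproducible seed sequences. A
stripped-down sketch of the striding pattern (get_seed and NUM_DATA_WORKERS
here are stand-ins, not the repo's definitions):

    from random import seed
    import numpy as np
    import torch

    NUM_DATA_WORKERS = 2                       # hypothetical worker count
    def get_seed(index):                       # stand-in for the repo's get_seed
        return (0x9e3779b9 * index) % (2 ** 31)

    def seed_stream(worker_id, offset=0):
        current = offset + worker_id
        while True:
            new_seed = get_seed(current)
            seed(new_seed)
            torch.manual_seed(new_seed)
            np.random.seed(new_seed)
            yield new_seed                     # the real loop yields (inputs, outputs)
            current += NUM_DATA_WORKERS        # stride keeps workers disjoint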