Fixing bugs in the probe training code.
asaparov committed Mar 14, 2024
1 parent df33c9b · commit 7a063f6
Showing 1 changed file with 16 additions and 15 deletions.
train_probe.py (31 changes: 16 additions & 15 deletions)
@@ -1,10 +1,11 @@
-from random import seed, randrange
+from random import seed, randrange, Random
 import numpy as np
 import torch
 from torch import nn, LongTensor, FloatTensor
 from torch.nn import BCEWithLogitsLoss, Sigmoid
 from torch.utils.data import DataLoader
 from Sophia import SophiaG
 import multiprocessing

 def build_module(name):
     from os import system
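The newly imported Random class is not used in any hunk shown here; presumably it gives some part of the script its own RNG instance, so drawing samples does not mutate the global random state. A minimal sketch of that pattern (illustrative only, not code from this repository):

from random import Random, seed

rng = Random(1)            # private generator, seeded once
seed(42)                   # reseeding the global RNG...
print(rng.randrange(100))  # ...does not affect the private stream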
@@ -136,9 +137,10 @@ def evaluate_decoder(model, max_input_size):
     print("exact match accuracy = %.2f, partial match accuracy = %.2f" % (exact_match / num_examples, partial_match / num_examples))

 if __name__ == "__main__":
-    seed(1)
-    torch.manual_seed(1)
-    np.random.seed(1)
+    seed_value = 1
+    seed(seed_value)
+    torch.manual_seed(seed_value)
+    np.random.seed(seed_value)

     from sys import argv, exit
     if len(argv) != 2:
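Replacing the three hard-coded 1s with a single seed_value keeps Python's random module, PyTorch, and NumPy seeded consistently from one place, so editing one call site can no longer desynchronize the others. The same idea as a helper, sketched with the hypothetical name set_all_seeds:

from random import seed
import numpy as np
import torch

def set_all_seeds(seed_value):
    # One entry point for every RNG the training script touches.
    seed(seed_value)               # Python's random module
    torch.manual_seed(seed_value)  # PyTorch, CPU and CUDA generators
    np.random.seed(seed_value)     # NumPy

set_all_seeds(1)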
@@ -153,7 +155,7 @@ def evaluate_decoder(model, max_input_size):
     else:
         device = torch.device('cuda')

-    tfm_model = torch.load(argv[1], map_location=device)
+    tfm_model, _, _, _ = torch.load(argv[1], map_location=device)
     model = TransformerProber(tfm_model, probe_layer=1)
     model.to(device)

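The fix implies these checkpoints are saved as a 4-tuple whose first element is the transformer; the old code handed the whole tuple to TransformerProber. What the other three elements are is not visible in this diff. A more defensive load, sketched under that assumption and reusing argv and device from the surrounding script:

checkpoint = torch.load(argv[1], map_location=device)
if isinstance(checkpoint, tuple):
    tfm_model = checkpoint[0]  # the transformer; remaining entries unused here
else:
    tfm_model = checkpoint     # older checkpoints stored the bare model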
@@ -198,23 +200,20 @@ def process_data(self, start):
 worker_info = torch.utils.data.get_worker_info()
 worker_id = worker_info.id
 while True:
-    worker_start_time = time.perf_counter()
     new_seed = get_seed(current)
     generator.set_seed(new_seed)
     seed(new_seed)
     torch.manual_seed(new_seed)
     np.random.seed(new_seed)

-    generate_start_time = time.perf_counter()
-    inputs, outputs, num_collisions = generate_reachable_training_set(max_input_size, BATCH_SIZE, lookahead_steps, reserved_inputs, dist_from_start, reachable_distance, start_vertex_index)
-    generator.generate_training_set(max_input_size, BATCH_SIZE, max_lookahead, reserved_inputs, dist_from_start, True)
+    inputs, outputs, num_collisions = generator.generate_reachable_training_set(max_input_size, BATCH_SIZE, lookahead_steps, reserved_inputs, 1, reachable_distance, start_vertex_index)
+    import pdb; pdb.set_trace()
     if num_collisions != 0:
         with self.collisions_lock:
             self.total_collisions.value += num_collisions
             print('Total number of training examples generated that are in the test set: {}'.format(self.total_collisions.value))
             stdout.flush()

-    worker_end_time = time.perf_counter()
     yield inputs, outputs
     current += NUM_DATA_WORKERS

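Each worker reseeds everything from get_seed(current) and then advances current by NUM_DATA_WORKERS, so worker k deterministically owns batch indices k, k + NUM_DATA_WORKERS, k + 2*NUM_DATA_WORKERS, and so on, with no overlap between workers. A self-contained sketch of that striping (the dataset below is a stand-in, not the repository's StreamingDataset):

import torch
from torch.utils.data import DataLoader, IterableDataset

NUM_WORKERS = 2

class StripedStream(IterableDataset):
    # Worker k yields items k, k + NUM_WORKERS, k + 2*NUM_WORKERS, ...
    def __iter__(self):
        current = torch.utils.data.get_worker_info().id
        while current < 20:             # bounded so the demo terminates
            torch.manual_seed(current)  # per-batch reseeding, as in the diff
            yield current
            current += NUM_WORKERS

if __name__ == "__main__":
    loader = DataLoader(StripedStream(), batch_size=None, num_workers=NUM_WORKERS)
    print(sorted(int(x) for x in loader))  # 0..19, each from exactly one worker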
@@ -223,10 +222,12 @@ def __iter__(self):
     worker_id = worker_info.id
     return self.process_data(self.offset + worker_id)

+epoch = 0
+BATCH_SIZE = 2 ** 11
 iterable_dataset = StreamingDataset(epoch * STREAMING_BLOCK_SIZE // BATCH_SIZE)
 train_loader = DataLoader(iterable_dataset, batch_size=None, num_workers=NUM_DATA_WORKERS, pin_memory=True, prefetch_factor=8)

-loss_func = BCELoss(reduction='mean')
+loss_func = BCEWithLogitsLoss(reduction='mean')
 optimizer = SophiaG((p for p in model.parameters() if p.requires_grad), lr=1.0e-3)

 log_interval = 1
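BCEWithLogitsLoss fuses the sigmoid into the loss and evaluates it with the log-sum-exp trick, so it stays finite for large-magnitude logits where an explicit sigmoid followed by BCELoss can saturate to exactly 0 or 1; it also means the probe is expected to emit raw logits during training (the Sigmoid import above is presumably for producing probabilities at evaluation time). The two formulations agree wherever both are well-behaved:

import torch
from torch.nn import BCELoss, BCEWithLogitsLoss

logits = torch.tensor([[-2.0, 0.5, 8.0]])
targets = torch.tensor([[0.0, 1.0, 1.0]])

fused = BCEWithLogitsLoss(reduction='mean')(logits, targets)
naive = BCELoss(reduction='mean')(torch.sigmoid(logits), targets)
print(fused.item(), naive.item())  # same value here; only the fused form
                                   # remains stable for extreme logits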
@@ -235,18 +236,18 @@ def __iter__(self):

 while True:
     for batch in train_loader:
-        batch_start_time = time.perf_counter()
         model.train()
         optimizer.zero_grad()

         input, output = batch
         input = input.to(device, non_blocking=True)
         output = output.to(device, non_blocking=True)

-        train_start_time = time.perf_counter()
-        transfer_time += train_start_time - batch_start_time
-
         logits = model(input)
+        # only take the predictions on source vertices
+        import pdb; pdb.set_trace()
+        logits = logits[:,range(2,max_input_size-5,3),:]
+        output = output[:,range(2,max_input_size-5,3),:]
         loss_val = loss_func(logits, output)
         epoch_loss += loss_val.item()

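range(2, max_input_size - 5, 3) keeps every third sequence position starting at index 2, consistent with the comment's claim that those are the source-vertex tokens; the underlying three-tokens-per-group input layout is an inference from this stride, not something the hunk shows. The index arithmetic on a toy size:

import torch

max_input_size = 24  # hypothetical; the real value is set elsewhere in the script
logits = torch.randn(4, max_input_size, 1)

source_positions = range(2, max_input_size - 5, 3)
print(list(source_positions))                # [2, 5, 8, 11, 14, 17]
print(logits[:, source_positions, :].shape)  # torch.Size([4, 6, 1])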
