Skip to content

Commit

Permalink
Merge pull request #146 from geometric-intelligence/wandb
Browse files Browse the repository at this point in the history
add wandb integration, update pyproject.toml and .gitignore
  • Loading branch information
franciscoeacosta authored Apr 19, 2024
2 parents ccae50c + c855b48 commit b33eccb
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 6 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ neurometry/datasets/rnn_grid_cells/Dual agent path integration disjoint PCs/*
neurometry/datasets/rnn_grid_cells/Single agent path integration/*

# Wandb files
wandb/*
*wandb/*
*logs/*

neurometry/curvature/grid-cells-curvature/models/xu_rnn/logs/*
neurometry/curvature/grid-cells-curvature/models/xu_rnn/wandb/*


# Result files
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ def get_config():

# training config
config.train = d(
num_steps_train=20, #100000
num_steps_train=25000, #100000
lr=0.006,
lr_decay_from=10000,
steps_per_logging=20,
steps_per_large_logging=5, #500
steps_per_large_logging=500, #500
steps_per_integration=2000,
norm_v=True,
positive_v=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
import model as model
import utils
import pickle
import wandb

class Experiment:
def __init__(self, config: ml_collections.ConfigDict, device):

self.config = config
self.device = device

wandb.init(project='grid-cell-rnns', entity='bioshape-lab', config=config.to_dict())

# initialize models
logging.info("==== initialize model ====")
self.model_config = model.GridCellConfig(**config.model)
Expand Down Expand Up @@ -118,6 +122,7 @@ def train_and_evaluate(self, workdir):
if step % config.steps_per_logging == 0 or step == 1:
train_metrics = utils.average_appended_metrics(train_metrics)
writer.write_scalars(step, train_metrics)
wandb.log({key: value for key, value in train_metrics.items()}, step=step)
train_metrics = []

if step % config.steps_per_large_logging == 0:
Expand All @@ -131,7 +136,9 @@ def visualize(activations, name):
activations = activations.data.cpu().detach().numpy()
activations = activations.reshape(
(-1, block_size, num_grid, num_grid))[:10, :10]
writer.write_images(step, {name: utils.draw_heatmap(activations)})
images = utils.draw_heatmap(activations)
writer.write_images(step, {name: images})
wandb.log({name: wandb.Image(images)}, step=step)

visualize(self.model.encoder.v, 'v')
visualize(self.model.decoder.u, 'u')
Expand Down Expand Up @@ -172,11 +179,13 @@ def visualize(activations, name):
heatmaps = heatmaps.cpu().detach().numpy()[None, ...]
writer.write_images(
step, {'vu_heatmap': utils.draw_heatmap(heatmaps)})
wandb.log({'vu_heatmap': wandb.Image(utils.draw_heatmap(heatmaps))}, step=step)

err = torch.mean(torch.sum((x_eval - x_pred) ** 2, dim=-1))
writer.write_scalars(step, {'pred_x': err.item()})
writer.write_scalars(step, {'error_fixed': error_fixed.item()})
writer.write_scalars(step, {'error_fixed_zero': error_fixed_zero.item()})
wandb.log({'pred_x': err.item(), 'error_fixed': error_fixed.item(), 'error_fixed_zero': error_fixed_zero.item()}, step=step)

if step % config.steps_per_integration == 0 or step == 1:
# perform path integration
Expand All @@ -193,6 +202,7 @@ def visualize(activations, name):
writer.write_scalars(step, {'score': score.item()})
writer.write_scalars(step, {'scale': scale_tensor[0].item() * num_grid})
writer.write_scalars(step, {'scale_mean': torch.mean(scale_tensor).item() * num_grid})
wandb.log({'score': score.item(), 'scale': scale_tensor[0].item() * num_grid, 'scale_mean': torch.mean(scale_tensor).item() * num_grid}, step=step)

# for visualization
if self.config.model.trans_type == 'nonlinear_simple':
Expand All @@ -209,6 +219,7 @@ def visualize(activations, name):
'heatmaps': utils.draw_heatmap(outputs['heatmaps'][:, ::5]),
}
writer.write_images(step, images)
wandb.log({key: wandb.Image(value) for key, value in images.items()}, step=step)

# for quantitative evaluation
if self.config.model.trans_type == 'nonlinear_simple':
Expand All @@ -218,6 +229,7 @@ def visualize(activations, name):

err = utils.dict_to_numpy(outputs['err'])
writer.write_scalars(step, err)
wandb.log({key: value for key, value in err.items()}, step=step)

if step == config.num_steps_train:
ckpt_dir = os.path.join(workdir, 'ckpt')
Expand Down Expand Up @@ -300,8 +312,9 @@ def _save_checkpoint(self, step, ckpt_dir):
if not tf.io.gfile.exists(model_dir):
tf.io.gfile.makedirs(model_dir)
model_filename = os.path.join(model_dir, 'checkpoint-step{}.pth'.format(step))
torch.save(state, model_filename)
logging.info("Saving model checkpoint: {} ...".format(model_filename))
torch.save(state, model_filename)
wandb.save(model_filename)

activations_dir = os.path.join(ckpt_dir, 'activations')
if not tf.io.gfile.exists(activations_dir):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
"config", None, "Training configuration.", lock_config=True)
flags.DEFINE_string("workdir", "../logs", "Work unit directory.")
flags.DEFINE_string("workdir", "logs", "Work unit directory.")
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("mode", 'train', "train / visualize / integration / correction")

Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ dependencies=[
"scikit-dimension",
"umap-learn",
"ripser",
"absl-py",
"ml-collections",
"tensowflow-cpu",
"clu",
"labml-helpers",
"giotto-ph @ git+https://github.com/alibayeh/giotto-ph.git",
"pyflagser @ git+https://github.com/alibayeh/pyflagser.git",
"giotto-tda @ git+https://github.com/alibayeh/giotto-tda.git",
Expand Down

0 comments on commit b33eccb

Please sign in to comment.