From c855b48345e329b9ce0f2a09415c04b4b094d4d9 Mon Sep 17 00:00:00 2001 From: Francisco Acosta Date: Thu, 18 Apr 2024 18:11:45 -0700 Subject: [PATCH] add wandb integration, update pyproject.toml and .gitignore --- .gitignore | 6 +++++- .../models/xu_rnn/configs/rnn_isometry.py | 4 ++-- .../models/xu_rnn/experiment.py | 17 +++++++++++++++-- .../grid-cells-curvature/models/xu_rnn/main.py | 2 +- pyproject.toml | 5 +++++ 5 files changed, 28 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 9d84df4..74ef776 100644 --- a/.gitignore +++ b/.gitignore @@ -23,7 +23,11 @@ neurometry/datasets/rnn_grid_cells/Dual agent path integration disjoint PCs/* neurometry/datasets/rnn_grid_cells/Single agent path integration/* # Wandb files -wandb/* +*wandb/* +*logs/* + +neurometry/curvature/grid-cells-curvature/models/xu_rnn/logs/* +neurometry/curvature/grid-cells-curvature/models/xu_rnn/wandb/* # Result files diff --git a/neurometry/curvature/grid-cells-curvature/models/xu_rnn/configs/rnn_isometry.py b/neurometry/curvature/grid-cells-curvature/models/xu_rnn/configs/rnn_isometry.py index 1227882..869577a 100644 --- a/neurometry/curvature/grid-cells-curvature/models/xu_rnn/configs/rnn_isometry.py +++ b/neurometry/curvature/grid-cells-curvature/models/xu_rnn/configs/rnn_isometry.py @@ -13,11 +13,11 @@ def get_config(): # training config config.train = d( - num_steps_train=20, #100000 + num_steps_train=25000, #100000 lr=0.006, lr_decay_from=10000, steps_per_logging=20, - steps_per_large_logging=5, #500 + steps_per_large_logging=500, #500 steps_per_integration=2000, norm_v=True, positive_v=True, diff --git a/neurometry/curvature/grid-cells-curvature/models/xu_rnn/experiment.py b/neurometry/curvature/grid-cells-curvature/models/xu_rnn/experiment.py index 2ad475a..56704f2 100644 --- a/neurometry/curvature/grid-cells-curvature/models/xu_rnn/experiment.py +++ b/neurometry/curvature/grid-cells-curvature/models/xu_rnn/experiment.py @@ -17,12 +17,16 @@ import model as model 
import utils import pickle +import wandb class Experiment: def __init__(self, config: ml_collections.ConfigDict, device): + self.config = config self.device = device + wandb.init(project='grid-cell-rnns', entity='bioshape-lab', config=config.to_dict()) + # initialize models logging.info("==== initialize model ====") self.model_config = model.GridCellConfig(**config.model) @@ -118,6 +122,7 @@ def train_and_evaluate(self, workdir): if step % config.steps_per_logging == 0 or step == 1: train_metrics = utils.average_appended_metrics(train_metrics) writer.write_scalars(step, train_metrics) + wandb.log({key: value for key, value in train_metrics.items()}, step=step) train_metrics = [] if step % config.steps_per_large_logging == 0: @@ -131,7 +136,9 @@ def visualize(activations, name): activations = activations.data.cpu().detach().numpy() activations = activations.reshape( (-1, block_size, num_grid, num_grid))[:10, :10] - writer.write_images(step, {name: utils.draw_heatmap(activations)}) + images = utils.draw_heatmap(activations) + writer.write_images(step, {name: images}) + wandb.log({name: wandb.Image(images)}, step=step) visualize(self.model.encoder.v, 'v') visualize(self.model.decoder.u, 'u') @@ -172,11 +179,13 @@ def visualize(activations, name): heatmaps = heatmaps.cpu().detach().numpy()[None, ...] 
writer.write_images( step, {'vu_heatmap': utils.draw_heatmap(heatmaps)}) + wandb.log({'vu_heatmap': wandb.Image(utils.draw_heatmap(heatmaps))}, step=step) err = torch.mean(torch.sum((x_eval - x_pred) ** 2, dim=-1)) writer.write_scalars(step, {'pred_x': err.item()}) writer.write_scalars(step, {'error_fixed': error_fixed.item()}) writer.write_scalars(step, {'error_fixed_zero': error_fixed_zero.item()}) + wandb.log({'pred_x': err.item(), 'error_fixed': error_fixed.item(), 'error_fixed_zero': error_fixed_zero.item()}, step=step) if step % config.steps_per_integration == 0 or step == 1: # perform path integration @@ -193,6 +202,7 @@ def visualize(activations, name): writer.write_scalars(step, {'score': score.item()}) writer.write_scalars(step, {'scale': scale_tensor[0].item() * num_grid}) writer.write_scalars(step, {'scale_mean': torch.mean(scale_tensor).item() * num_grid}) + wandb.log({'score': score.item(), 'scale': scale_tensor[0].item() * num_grid, 'scale_mean': torch.mean(scale_tensor).item() * num_grid}, step=step) # for visualization if self.config.model.trans_type == 'nonlinear_simple': @@ -209,6 +219,7 @@ def visualize(activations, name): 'heatmaps': utils.draw_heatmap(outputs['heatmaps'][:, ::5]), } writer.write_images(step, images) + wandb.log({key: wandb.Image(value) for key, value in images.items()}, step=step) # for quantitative evaluation if self.config.model.trans_type == 'nonlinear_simple': @@ -218,6 +229,7 @@ def visualize(activations, name): err = utils.dict_to_numpy(outputs['err']) writer.write_scalars(step, err) + wandb.log({key: value for key, value in err.items()}, step=step) if step == config.num_steps_train: ckpt_dir = os.path.join(workdir, 'ckpt') @@ -300,8 +312,9 @@ def _save_checkpoint(self, step, ckpt_dir): if not tf.io.gfile.exists(model_dir): tf.io.gfile.makedirs(model_dir) model_filename = os.path.join(model_dir, 'checkpoint-step{}.pth'.format(step)) - torch.save(state, model_filename) logging.info("Saving model checkpoint: {} 
...".format(model_filename)) + torch.save(state, model_filename) + wandb.save(model_filename) activations_dir = os.path.join(ckpt_dir, 'activations') if not tf.io.gfile.exists(activations_dir): diff --git a/neurometry/curvature/grid-cells-curvature/models/xu_rnn/main.py b/neurometry/curvature/grid-cells-curvature/models/xu_rnn/main.py index 855a1b7..b4065da 100644 --- a/neurometry/curvature/grid-cells-curvature/models/xu_rnn/main.py +++ b/neurometry/curvature/grid-cells-curvature/models/xu_rnn/main.py @@ -13,7 +13,7 @@ FLAGS = flags.FLAGS config_flags.DEFINE_config_file( "config", None, "Training configuration.", lock_config=True) -flags.DEFINE_string("workdir", "../logs", "Work unit directory.") +flags.DEFINE_string("workdir", "logs", "Work unit directory.") flags.mark_flags_as_required(["config"]) flags.DEFINE_string("mode", 'train', "train / visualize / integration / correction") diff --git a/pyproject.toml b/pyproject.toml index 1fd5a48..ba8193c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,11 @@ dependencies=[ "scikit-dimension", "umap-learn", "ripser", + "absl-py", + "ml-collections", + "tensorflow-cpu", + "clu", + "labml-helpers", "giotto-ph @ git+https://github.com/alibayeh/giotto-ph.git", "pyflagser @ git+https://github.com/alibayeh/pyflagser.git", "giotto-tda @ git+https://github.com/alibayeh/giotto-tda.git",