Skip to content

Commit

Permalink
Merge pull request #149 from geometric-intelligence/analyze_dual_agents
Browse files Browse the repository at this point in the history
Run topo vae, Analyze dual agents
  • Loading branch information
franciscoeacosta authored Apr 25, 2024
2 parents 99a8043 + 5dad96d commit 6852b45
Show file tree
Hide file tree
Showing 18 changed files with 11,101 additions and 6,515 deletions.
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# API KEY
neurometry/api_key.txt
*api_key.txt

neurometry/results/*
neurometry/wandb/*

neurometry/datasets/rnn_grid_cells/Dual agent path integration high res/*
neurometry/datasets/rnn_grid_cells/Single agent path integration high res/*


*viewer*
Expand Down
9 changes: 5 additions & 4 deletions neurometry/curvature/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# WANDB API KEY
# Find it here: https://wandb.ai/authorize
# Story it in file: api_key.txt (without extra line break)
with open("api_key.txt") as f:
api_key_path = os.path.join(os.getcwd(), "curvature","api_key.txt")
with open(api_key_path) as f:
api_key = f.read()

# Directories
Expand Down Expand Up @@ -139,7 +140,7 @@

# Datasets
# dataset_name = ["s1_synthetic", "s2_synthetic"]
dataset_name = ["kb_synthetic"]
dataset_name = ["s1_synthetic"]
for one_dataset_name in dataset_name:
if one_dataset_name not in [
"s1_synthetic",
Expand All @@ -163,10 +164,10 @@

# Only used of dataset_name in ["s1_synthetic", "s2_synthetic", "t2_synthetic"]
n_times = [2500] # , 2000] # actual number of times is sqrt_ntimes ** 2
embedding_dim = [3, 10, 20, 30] # for s1 stopped at 5 (not done, but 3 was done)
embedding_dim = [5] # for s1 stopped at 5 (not done, but 3 was done)
geodesic_distortion_amp = [0.4]
# TODO: Add 0.03, possibly 0,000[1
noise_var = [0.1, 0.075, 0.05, 0.03, 0.01, 0.005, 0.001] # , 1e-2, 1e-1] 0.075, 0.1] #[
noise_var = [0.1] # , 1e-2, 1e-1] 0.075, 0.1] #[

# Only used if dataset_name == "grid_cells"
grid_scale = [1.0]
Expand Down
4 changes: 2 additions & 2 deletions neurometry/curvature/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import geomstats.backend as gs # noqa: E402

# import gph
from datasets.synthetic import ( # noqa: E402
from neurometry.curvature.datasets.synthetic import ( # noqa: E402
get_s1_synthetic_immersion,
get_s2_synthetic_immersion,
get_t2_synthetic_immersion,
Expand Down Expand Up @@ -106,7 +106,7 @@ def get_z_grid(config, n_grid_points=100):
z_grid = torch.cartesian_prod(thetas, phis)
return z_grid


#TODO: change instantiation of PullbackMetric to match latest geomstats version
def _compute_curvature(z_grid, immersion, dim, embedding_dim):
"""Compute mean curvature vector and its norm at each point."""
neural_metric = PullbackMetric(
Expand Down
2 changes: 2 additions & 0 deletions neurometry/curvature/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def elbo(x, x_mu, posterior_params, z, labels, config):

if config.gen_likelihood_type == "gaussian":
recon_loss = torch.mean((x - x_mu).pow(2))
else:
raise NotImplementedError

if config.dataset_name == "s1_synthetic":
recon_loss = recon_loss / (config.radius**2)
Expand Down
26 changes: 13 additions & 13 deletions neurometry/curvature/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@

# from ray.tune.integration.wandb import wandb_mixin
import torch
import train
import viz
import neurometry.curvature.train as train
import neurometry.curvature.viz as viz
import wandb
from ray import air, tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.search.hyperopt import HyperOptSearch

os.environ["GEOMSTATS_BACKEND"] = "pytorch"
import datasets.utils # noqa: E402
import default_config # noqa: E402
import evaluate # noqa: E402
import neurometry.curvature.datasets.utils as utils # noqa: E402
import neurometry.curvature.default_config as default_config # noqa: E402
import neurometry.curvature.evaluate as evaluate # noqa: E402
import geomstats.backend as gs # noqa: E402
import models.klein_bottle_vae # noqa: E402
import models.neural_vae # noqa: E402
import models.toroidal_vae # noqa: E402
import neurometry.curvature.models.klein_bottle_vae as klein_bottle_vae # noqa: E402
import neurometry.curvature.models.neural_vae as neural_vae # noqa: E402
import neurometry.curvature.models.toroidal_vae as toroidal_vae # noqa: E402

# Required to make matplotlib figures in threads:
matplotlib.use("Agg")
Expand Down Expand Up @@ -262,7 +262,7 @@ def main_run(sweep_config):
wandb.run.name = run_name

# Load data, labels
dataset, labels, train_loader, test_loader = datasets.utils.load(wandb_config)
dataset, labels, train_loader, test_loader = utils.load(wandb_config)
data_n_times, data_dim = dataset.shape
wandb_config.update(
{
Expand Down Expand Up @@ -344,7 +344,7 @@ def create_model_and_train_test(config, train_loader, test_loader):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
model = models.neural_vae.NeuralVAE(
model = neural_vae.NeuralVAE(
data_dim=data_dim,
latent_dim=config.latent_dim,
sftbeta=config.sftbeta,
Expand All @@ -360,7 +360,7 @@ def create_model_and_train_test(config, train_loader, test_loader):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
model = models.toroidal_vae.ToroidalVAE(
model = toroidal_vae.ToroidalVAE(
data_dim=data_dim,
latent_dim=config.latent_dim,
sftbeta=config.sftbeta,
Expand All @@ -375,7 +375,7 @@ def create_model_and_train_test(config, train_loader, test_loader):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
model = models.klein_bottle_vae.KleinBottleVAE(
model = klein_bottle_vae.KleinBottleVAE(
data_dim=data_dim,
latent_dim=config.latent_dim,
sftbeta=config.sftbeta,
Expand Down Expand Up @@ -587,4 +587,4 @@ def curvature_compute_plot_log(config, dataset, labels, model):
plt.close("all")


main()
#main()
2 changes: 1 addition & 1 deletion neurometry/curvature/models/klein_bottle_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import geomstats.backend as gs
import torch
from hyperspherical.distributions import VonMisesFisher
from neurometry.curvature.hyperspherical.distributions.von_mises_fisher import VonMisesFisher
from torch.nn import functional as F


Expand Down
2 changes: 1 addition & 1 deletion neurometry/curvature/models/neural_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import torch
from hyperspherical.distributions import VonMisesFisher
from neurometry.curvature.hyperspherical.distributions.von_mises_fisher import VonMisesFisher
from torch.distributions.normal import Normal
from torch.nn import functional as F

Expand Down
2 changes: 1 addition & 1 deletion neurometry/curvature/models/toroidal_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import geomstats.backend as gs
import torch
from hyperspherical.distributions import VonMisesFisher
from neurometry.curvature.hyperspherical.distributions.von_mises_fisher import VonMisesFisher
from torch.nn import functional as F


Expand Down
2 changes: 1 addition & 1 deletion neurometry/curvature/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import copy

import losses
import neurometry.curvature.losses as losses
import torch
import wandb

Expand Down
82 changes: 31 additions & 51 deletions neurometry/datasets/load_rnn_grid_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,60 +12,31 @@
utils,
)

# Loading single agent model

# parent_dir = os.getcwd() + "/datasets/rnn_grid_cells/"

parent_dir = "/scratch/facosta/rnn_grid_cells/"


single_model_folder = "Single agent path integration/Seed 1 weight decay 1e-06/"
single_model_parameters = "steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_1e-06/"


dual_model_folder = (
"Dual agent path integration disjoint PCs/Seed 1 weight decay 1e-06/"
)
dual_model_parameters = "steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_1e-06/"


def load_activations(epochs, version="single", verbose=True):
def load_activations(epochs, file_path, version="single", verbose=True, save = True):
activations = []
rate_maps = []
state_points = []
positions = []
g_s = []

if version == "single":
activations_dir = (
parent_dir + single_model_folder + single_model_parameters + "activations/"
)
elif version == "dual":
activations_dir = (
parent_dir + dual_model_folder + dual_model_parameters + "activations/"
)
activations_dir = os.path.join(file_path, "activations")

random.seed(0)
for epoch in epochs:
activations_epoch_path = (
activations_dir + f"activations_{version}_agent_epoch_{epoch}.npy"
)
rate_map_epoch_path = (
activations_dir + f"rate_map_{version}_agent_epoch_{epoch}.npy"
)
positions_epoch_path = (
activations_dir + f"positions_{version}_agent_epoch_{epoch}.npy"
)

if (
os.path.exists(activations_epoch_path)
and os.path.exists(rate_map_epoch_path)
and os.path.exists(positions_epoch_path)
):
activations_epoch_path = os.path.join(activations_dir, f"activations_{version}_agent_epoch_{epoch}.npy")
rate_map_epoch_path = os.path.join(activations_dir, f"rate_map_{version}_agent_epoch_{epoch}.npy")
positions_epoch_path = os.path.join(activations_dir, f"positions_{version}_agent_epoch_{epoch}.npy")
gs_epoch_path = os.path.join(activations_dir, f"g_{version}_agent_epoch_{epoch}.npy")

if os.path.exists(activations_epoch_path) and os.path.exists(
rate_map_epoch_path
) and os.path.exists(positions_epoch_path) and os.path.exists(gs_epoch_path):
activations.append(np.load(activations_epoch_path))
rate_maps.append(np.load(rate_map_epoch_path))
positions.append(np.load(positions_epoch_path))
g_s.append(np.load(gs_epoch_path))
if verbose:
print(f"Epoch {epoch} found!")
print(f"Epoch {epoch} found.")
else:
print(f"Epoch {epoch} not found. Loading ...")
parser = config.parser
Expand All @@ -75,22 +46,32 @@ def load_activations(epochs, version="single", verbose=True):
(
activations_single_agent,
rate_map_single_agent,
g_single_agent,
positions_single_agent,
) = single_agent_activity.main(options, epoch=epoch)
) = single_agent_activity.main(options, file_path, epoch=epoch)
activations.append(activations_single_agent)
rate_maps.append(rate_map_single_agent)
positions.append(positions_single_agent)
g_s.append(g_single_agent)
elif version == "dual":
activations_dual_agent, rate_map_dual_agent, positions_dual_agent = (
dual_agent_activity.main(options, epoch=epoch)
)
activations_dual_agent, rate_map_dual_agent, g_dual_agent, positions_dual_agent = dual_agent_activity.main(
options, file_path, epoch=epoch)
activations.append(activations_dual_agent)
rate_maps.append(rate_map_dual_agent)
positions.append(positions_dual_agent)
print(len(activations))
g_s.append(g_dual_agent)

if save:
np.save(activations_epoch_path, activations[-1])
np.save(rate_map_epoch_path, rate_maps[-1])
np.save(positions_epoch_path, positions[-1])
np.save(gs_epoch_path, g_s[-1])

state_points_epoch = activations[-1].reshape(activations[-1].shape[0], -1)
state_points.append(state_points_epoch)



if verbose:
print(f"Loaded epochs {epochs} of {version} agent model.")
print(
Expand All @@ -104,7 +85,7 @@ def load_activations(epochs, version="single", verbose=True):
)
print(f"positions has shape {positions[0].shape}.")

return activations, rate_maps, state_points, positions
return activations, rate_maps, state_points, positions, g_s


# def plot_rate_map(indices, num_plots, activations, title):
Expand Down Expand Up @@ -137,9 +118,8 @@ def load_activations(epochs, version="single", verbose=True):
# plt.show()



def plot_rate_map(indices, num_plots, activations, title):
rng = np.random.default_rng(seed=0)
def plot_rate_map(indices, num_plots, activations, title, seed=None):
rng = np.random.default_rng(seed=seed)
if indices is None:
idxs = rng.integers(0, activations.shape[0] - 1, num_plots)
else:
Expand Down
1 change: 1 addition & 0 deletions neurometry/datasets/rnn_grid_cells/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Config:
n_avg = 50 # number of trajectories to average over for rate maps



# If you need to access the configuration as a dictionary
config = Config.__dict__

Expand Down
Loading

0 comments on commit 6852b45

Please sign in to comment.