Merge pull request #152 from geometric-intelligence/fix_serialization
Fix serialization
franciscoeacosta authored Apr 26, 2024
2 parents f2cba7c + 8e9a05a commit e21ca8a
Showing 21 changed files with 211 additions and 479 deletions.
35 changes: 22 additions & 13 deletions neurometry/curvature/datasets/utils.py
@@ -6,7 +6,18 @@
import torch
from scipy.signal import savgol_filter

-import neurometry.curvature.datasets as datasets
+from neurometry.curvature.datasets.experimental import load_neural_activity
+from neurometry.curvature.datasets.gridcells import load_grid_cells_synthetic
+from neurometry.curvature.datasets.synthetic import (
+load_images,
+load_place_cells,
+load_points,
+load_projected_images,
+load_s1_synthetic,
+load_s2_synthetic,
+load_t2_synthetic,
+)


def load(config):
@@ -31,7 +42,7 @@ def load(config):
test dataset.
"""
if config.dataset_name == "experimental":
-dataset, labels = datasets.experimental.load_neural_activity(
+dataset, labels = load_neural_activity(
expt_id=config.expt_id, timestep_microsec=config.timestep_microsec
)
dataset = dataset[labels["velocities"] > 5]
@@ -72,24 +83,22 @@ def load(config):
labels = labels[labels["gains"] == gain]

elif config.dataset_name == "synthetic":
-dataset, labels = datasets.synthetic.load_place_cells()
+dataset, labels = load_place_cells()
dataset = np.log(dataset.astype(np.float32) + 1)
dataset = (dataset - np.min(dataset)) / (np.max(dataset) - np.min(dataset))
elif config.dataset_name == "images":
-dataset, labels = datasets.synthetic.load_images(img_size=config.img_size)
+dataset, labels = load_images(img_size=config.img_size)
dataset = (dataset - np.min(dataset)) / (np.max(dataset) - np.min(dataset))
height, width = dataset.shape[1:3]
dataset = dataset.reshape((-1, height * width))
elif config.dataset_name == "projected_images":
-dataset, labels = datasets.synthetic.load_projected_images(
-img_size=config.img_size
-)
+dataset, labels = load_projected_images(img_size=config.img_size)
dataset = (dataset - np.min(dataset)) / (np.max(dataset) - np.min(dataset))
elif config.dataset_name == "points":
-dataset, labels = datasets.synthetic.load_points()
+dataset, labels = load_points()
dataset = dataset.astype(np.float32)
elif config.dataset_name == "s1_synthetic":
-dataset, labels = datasets.synthetic.load_s1_synthetic(
+dataset, labels = load_s1_synthetic(
synthetic_rotation=config.synthetic_rotation,
n_times=config.n_times,
radius=config.radius,
@@ -100,7 +109,7 @@
geodesic_distortion_func=config.geodesic_distortion_func,
)
elif config.dataset_name == "s2_synthetic":
-dataset, labels = datasets.synthetic.load_s2_synthetic(
+dataset, labels = load_s2_synthetic(
synthetic_rotation=config.synthetic_rotation,
n_times=config.n_times,
radius=config.radius,
@@ -109,7 +118,7 @@
noise_var=config.noise_var,
)
elif config.dataset_name == "t2_synthetic":
-dataset, labels = datasets.synthetic.load_t2_synthetic(
+dataset, labels = load_t2_synthetic(
synthetic_rotation=config.synthetic_rotation,
n_times=config.n_times,
major_radius=config.major_radius,
@@ -119,7 +128,7 @@
noise_var=config.noise_var,
)
elif config.dataset_name == "grid_cells":
-dataset, labels = datasets.gridcells.load_grid_cells_synthetic(
+dataset, labels = load_grid_cells_synthetic(
grid_scale=config.grid_scale,
arena_dims=config.arena_dims,
n_cells=config.n_cells,
@@ -129,7 +138,7 @@
resolution=config.resolution,
)
elif config.dataset_name == "three_place_cells_synthetic":
-dataset, labels = datasets.synthetic.load_three_place_cells()
+dataset, labels = load_three_place_cells()
print(f"Dataset shape: {dataset.shape}.")
if type(dataset) == np.ndarray:
dataset_torch = torch.from_numpy(dataset)
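The trailing context above converts NumPy datasets to torch tensors behind a type(dataset) == np.ndarray check. A minimal sketch of that last step (a hypothetical helper, not part of this diff), using the more idiomatic isinstance test:

import numpy as np
import torch

def to_torch(dataset):
    """Mirror the tail of load(): hand back a torch tensor for NumPy input."""
    # isinstance is preferred over type(...) == np.ndarray; both behave the
    # same for plain arrays, but isinstance also covers subclasses.
    if isinstance(dataset, np.ndarray):
        return torch.from_numpy(dataset)
    return dataset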
28 changes: 12 additions & 16 deletions neurometry/curvature/default_config.py
@@ -11,7 +11,7 @@
# WANDB API KEY
# Find it here: https://wandb.ai/authorize
# Store it in file: api_key.txt (without extra line break)
-api_key_path = os.path.join(os.getcwd(),"api_key.txt")
+api_key_path = os.path.join(os.getcwd(), "api_key.txt")
with open(api_key_path) as f:
api_key = f.read()

@@ -28,18 +28,14 @@
if not os.path.exists(curvature_profiles_dir):
os.makedirs(curvature_profiles_dir)

-print(configs_dir)
-print(trained_models_dir)


# Hardware
device = "cuda" if torch.cuda.is_available() else "cpu"

# Can be replaced by logging.DEBUG or logging.WARNING
logging.basicConfig(level=logging.INFO)

# Results
-project = "neurometry"
+project = "topo-vae"
trained_model_path = None

### Fixed experiment parameters ###
@@ -164,10 +160,10 @@

# Only used if dataset_name in ["s1_synthetic", "s2_synthetic", "t2_synthetic"]
n_times = [2500] # , 2000] # actual number of times is sqrt_ntimes ** 2
-embedding_dim = [5] # for s1 stopped at 5 (not done, but 3 was done)
+embedding_dim = [3] # for s1 stopped at 5 (not done, but 3 was done)
geodesic_distortion_amp = [0.4]
# TODO: Add 0.03, possibly 0.0001
-noise_var = [0.1] # , 1e-2, 1e-1] 0.075, 0.1] #[
+noise_var = [1e-5] # , 1e-2, 1e-1] 0.075, 0.1] #[

# Only used if dataset_name == "grid_cells"
grid_scale = [1.0]
@@ -188,7 +184,7 @@
scheduler = False
log_interval = 20
checkpt_interval = 20
-n_epochs = 60 # 00 # 00 # 50 # 200 # 150 # 240
+n_epochs = 400 # 00 # 00 # 50 # 200 # 150 # 240
sftbeta = 4.5 # beta parameter for softplus
alpha = 1.0 # weight for the reconstruction loss
beta = 0.03 # 0.03 # weight for KL loss
@@ -202,14 +198,14 @@
### Ray sweep hyperparameters ###
# --> Lists of values to sweep for each hyperparameter
# Except for lr_min and lr_max which are floats
-lr_min = 0.0001
+lr_min = [0.001] # 0.0001
lr_max = 0.1
-batch_size = [16, 64, 128] # [16,32,64]
-encoder_width = [200, 400] # [100,400] # , 100, 200, 300]
-encoder_depth = [4, 10, 12] # [4,6,8] # , 10, 20, 50, 100]
-decoder_width = [200, 400] # [100,400] # , 100, 200, 300]
-decoder_depth = [4, 6, 8] # [4,6,8] # , 10, 20, 50, 100]
-drop_out_p = [0, 0.1] # [0,0.1,0.2] # put probability p at 0. for no drop out
+batch_size = [64] # [16,32,64]
+encoder_width = [400] # [100,400] # , 100, 200, 300]
+encoder_depth = [10] # [4,6,8] # , 10, 20, 50, 100]
+decoder_width = [200] # [100,400] # , 100, 200, 300]
+decoder_depth = [6] # [4,6,8] # , 10, 20, 50, 100]
+drop_out_p = [0] # [0,0.1,0.2] # put probability p at 0. for no drop out
for p in drop_out_p:
assert p >= 0.0 and p <= 1, "Probability needs to be in [0, 1]"

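One inconsistency worth flagging: the comment above still says lr_min and lr_max are floats, yet the new line makes lr_min a one-element list. Assuming these lists feed Ray Tune, which the "### Ray sweep hyperparameters ###" heading suggests, the wiring might look like the following sketch (hypothetical; the trainer is not shown in this diff):

from ray import tune

# Hypothetical hookup of the sweep lists above to a Ray Tune search space.
param_space = {
    "batch_size": tune.grid_search(batch_size),
    "encoder_width": tune.grid_search(encoder_width),
    "encoder_depth": tune.grid_search(encoder_depth),
    "decoder_width": tune.grid_search(decoder_width),
    "decoder_depth": tune.grid_search(decoder_depth),
    "drop_out_p": tune.grid_search(drop_out_p),
    # lr_min is now a list, so a continuous range needs its scalar element:
    "lr": tune.loguniform(lr_min[0], lr_max),
}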
7 changes: 4 additions & 3 deletions neurometry/curvature/evaluate.py
@@ -6,15 +6,15 @@

os.environ["GEOMSTATS_BACKEND"] = "pytorch"
import geomstats.backend as gs # noqa: E402
-from geomstats.geometry.pullback_metric import PullbackMetric # noqa: E402
-from geomstats.geometry.special_orthogonal import SpecialOrthogonal # noqa: E402

# import gph
from neurometry.curvature.datasets.synthetic import ( # noqa: E402
get_s1_synthetic_immersion,
get_s2_synthetic_immersion,
get_t2_synthetic_immersion,
)
+from geomstats.geometry.pullback_metric import PullbackMetric # noqa: E402
+from geomstats.geometry.special_orthogonal import SpecialOrthogonal # noqa: E402


def get_learned_immersion(model, config):
@@ -106,7 +106,8 @@ def get_z_grid(config, n_grid_points=100):
z_grid = torch.cartesian_prod(thetas, phis)
return z_grid

-#TODO: change instantiation of PullbackMetric to match latest geomstats version
+
+# TODO: change instantiation of PullbackMetric to match latest geomstats version
def _compute_curvature(z_grid, immersion, dim, embedding_dim):
"""Compute mean curvature vector and its norm at each point."""
neural_metric = PullbackMetric(
@@ -1,4 +1,3 @@
-
import torch
from labml_helpers.module import Module
from torch import nn
@@ -18,6 +18,7 @@ def __init__(
self.dx_list = self._generate_dx_list(config.max_dr_trans)
# self.dx_list = self._generate_dx_list_continous(config.max_dr_trans)
self.scale_vector = np.zeros(self.num_blocks) + config.max_dr_isometry
+self.rng = self.rng.default_rng()

def __iter__(self):
while True:
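The line added above seeds this file's migration from the legacy np.random.* global state to NumPy's Generator API, which every hunk below relies on. Note that, as written, self.rng = self.rng.default_rng() would raise AttributeError because self.rng does not exist yet; presumably np.random.default_rng() was intended. A minimal sketch of the migration pattern:

import numpy as np

rng = np.random.default_rng(seed=0)  # replaces the implicit global state

# Legacy module-level calls and their Generator equivalents:
theta = rng.random(size=8) * 2 * np.pi  # was np.random.random(...)
dr = np.abs(rng.normal(size=8))         # was np.random.normal(...)
idx = rng.choice(8, size=4)             # was np.random.choice(...)
rng.shuffle(dr)                         # was np.random.shuffle(...)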
@@ -42,9 +43,9 @@ def _gen_data_kernel(self):
batch_size = self.config.batch_size
config = self.config

-theta = np.random.random(size=int(batch_size * 1.5)) * 2 * np.pi
+theta = self.rng.random(size=int(batch_size * 1.5)) * 2 * np.pi
dr = (
-np.abs(np.random.normal(size=int(batch_size * 1.5)) * config.sigma_data)
+np.abs(self.rng.normal(size=int(batch_size * 1.5)) * config.sigma_data)
* self.num_grid
)
dx = _dr_theta_to_dx(dr, theta)
@@ -57,7 +58,7 @@ def _gen_data_kernel(self):
x_max, x_min, dx = x_max[select_idx], x_min[select_idx], dx[select_idx]
assert len(dx) == batch_size

-x = np.random.random(size=(batch_size, 2)) * (x_max - x_min) + x_min
+x = self.rng.random(size=(batch_size, 2)) * (x_max - x_min) + x_min
x_prime = x + dx

return {"x": x, "x_prime": x_prime}
@@ -69,7 +70,7 @@ def _gen_data_trans_rnn(self):
n_steps = self.rnn_step
dx_list = self.dx_list

-dx_idx = np.random.choice(len(dx_list), size=[n_traj * 10, n_steps])
+dx_idx = self.rng.choice(len(dx_list), size=[n_traj * 10, n_steps])
dx = dx_list[dx_idx] # [N, T, 2]
dx_cumsum = np.cumsum(dx, axis=1) # [N, T, 2]

@@ -86,7 +87,7 @@ def _gen_data_trans_rnn(self):
x_start_max, x_start_min = x_start_max[select_idx], x_start_min[select_idx]
dx_cumsum = dx_cumsum[select_idx]
x_start = (
-np.random.random((n_traj, 2)) * (x_start_max - x_start_min) + x_start_min
+self.rng.random((n_traj, 2)) * (x_start_max - x_start_min) + x_start_min
)
x_start = x_start[:, None] # [N, 1, 2]
x_start = np.round(x_start - 0.5)
Expand All @@ -99,13 +100,13 @@ def _gen_data_iso_numerical(self):
batch_size = self.config.batch_size
config = self.config

-theta = np.random.random(size=(batch_size, 2)) * 2 * np.pi
-dr = np.sqrt(np.random.random(size=(batch_size, 1))) * config.max_dr_isometry
+theta = self.rng.random(size=(batch_size, 2)) * 2 * np.pi
+dr = np.sqrt(self.rng.random(size=(batch_size, 1))) * config.max_dr_isometry
dx = _dr_theta_to_dx(dr, theta) # [N, 2, 2]

x_max = np.fmin(self.num_grid - 0.5, np.min(self.num_grid - 0.5 - dx, axis=1))
x_min = np.fmax(-0.5, np.max(-0.5 - dx, axis=1))
-x = np.random.random(size=(batch_size, 2)) * (x_max - x_min) + x_min
+x = self.rng.random(size=(batch_size, 2)) * (x_max - x_min) + x_min
x_plus_dx1 = x + dx[:, 0]
x_plus_dx2 = x + dx[:, 1]

@@ -117,18 +118,18 @@ def _gen_data_iso_numerical_adaptive(self):
config = self.config

theta = (
-np.random.random(size=(batch_size, num_blocks, 2)) * 2 * np.pi
+self.rng.random(size=(batch_size, num_blocks, 2)) * 2 * np.pi
) # (batch_size, num_blocks, 2)
dr = (
-np.sqrt(np.random.random(size=(batch_size, num_blocks, 1)))
+np.sqrt(self.rng.random(size=(batch_size, num_blocks, 1)))
* np.tile(self.scale_vector, (batch_size, 1))[:, :, None]
) # (batch_size, num_blocks, 1)
dx = _dr_theta_to_dx(dr, theta) # [N, num_blocks, 2, 2]

x_max = np.fmin(self.num_grid - 0.5, np.min(self.num_grid - 0.5 - dx, axis=2))
x_min = np.fmax(-0.5, np.max(-0.5 - dx, axis=2))
x = (
-np.random.random(size=(batch_size, num_blocks, 2)) * (x_max - x_min) + x_min
+self.rng.random(size=(batch_size, num_blocks, 2)) * (x_max - x_min) + x_min
) # (batch_size, num_blocks, 2)
x_plus_dx1 = x + dx[:, :, 0]
x_plus_dx2 = x + dx[:, :, 1]
@@ -157,9 +158,9 @@ def _generate_dx_list_continous(self, max_dr):
dx_list = []
batch_size = self.config.batch_size

-dr = np.sqrt(np.random.random(size=(batch_size,))) * max_dr
-np.random.shuffle(dr)
-theta = np.random.random(size=(batch_size,)) * 2 * np.pi
+dr = np.sqrt(self.rng.random(size=(batch_size,))) * max_dr
+self.rng.shuffle(dr)
+theta = self.rng.random(size=(batch_size,)) * 2 * np.pi

dx = _dr_theta_to_dx(dr, theta)

@@ -202,7 +203,7 @@ def _gen_trajectory_vis(self, n_traj, n_steps):
x_start = np.reshape([5, 5], newshape=(1, 1, 2)) # [1, 1, 2]
dx_idx_pool = np.where((dx_list[:, 0] >= -1) & (dx_list[:, 1] >= -1))[0]
# dx_idx_pool = np.where((dx_list[:, 0] >= 0) & (dx_list[:, 1] >= -1))[0]
-dx_idx = np.random.choice(dx_idx_pool, size=[n_traj * 50, n_steps])
+dx_idx = self.rng.choice(dx_idx_pool, size=[n_traj * 50, n_steps])
dx = dx_list[dx_idx]
dx_cumsum = np.cumsum(dx, axis=1) # [N, T, 2]

@@ -224,7 +225,7 @@ def _gen_trajectory(self, n_traj, n_steps):
# uniformly within the whole region.
dx_list = self.dx_list

-dx_idx = np.random.choice(len(dx_list), size=[n_traj * 10, n_steps])
+dx_idx = self.rng.choice(len(dx_list), size=[n_traj * 10, n_steps])
dx = dx_list[dx_idx] # [N, T, 2]
dx_cumsum = np.cumsum(dx, axis=1) # [N, T, 2]

@@ -241,7 +242,7 @@ def _gen_trajectory(self, n_traj, n_steps):
x_start_max, x_start_min = x_start_max[select_idx], x_start_min[select_idx]
dx_cumsum = dx_cumsum[select_idx]
x_start = (
-np.random.random((n_traj, 2)) * (x_start_max - x_start_min) + x_start_min
+self.rng.random((n_traj, 2)) * (x_start_max - x_start_min) + x_start_min
)
x_start = x_start[:, None] # [N, 1, 2]
x_start = np.round(x_start - 0.5)
@@ -507,10 +507,9 @@ def get_grid_code(codebook, x, num_grid):
align_corners=False,
) # [1, C, 1, N]

-v_x = torch.squeeze(torch.squeeze(v_x, 0), 1).transpose(0, 1) # [N, C]
# v_x = v_x.squeeze().transpose(0, 1)

-return v_x
+return torch.squeeze(torch.squeeze(v_x, 0), 1).transpose(0, 1) # [N, C]


def get_grid_code_block(codebook, x, num_grid, block_size):
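The rewritten return keeps torch.squeeze with explicit dims instead of reviving the commented-out bare .squeeze(): with explicit dims, a size-1 batch axis cannot be dropped by accident. A small illustration with the [1, C, 1, N] shape noted above, taking C = 3 and N = 1:

import torch

v_x = torch.randn(1, 3, 1, 1)  # [1, C, 1, N]

explicit = torch.squeeze(torch.squeeze(v_x, 0), 1).transpose(0, 1)
print(explicit.shape)  # torch.Size([1, 3]): the [N, C] layout survives

bare = v_x.squeeze()  # removes every size-1 dim, including the N axis
print(bare.shape)     # torch.Size([3])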
@@ -537,7 +536,10 @@ def get_grid_code_int(codebook, x, num_grid):

# query the 2D codebook, no interpolation
v_x = torch.vstack(
-[codebook[:, i, j] for i, j in zip(x_normalized[:, 0], x_normalized[:, 1], strict=False)]
+[
+codebook[:, i, j]
+for i, j in zip(x_normalized[:, 0], x_normalized[:, 1], strict=False)
+]
)
# v_x = v_x.squeeze().transpose(0, 1) # [N, C]

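For context, get_grid_code appears to query the codebook with bilinear interpolation (the align_corners=False argument and the [1, C, 1, N] shape comment above suggest torch.nn.functional.grid_sample), while get_grid_code_int indexes integer cells directly. A sketch of an interpolated lookup under that assumption, with all names and shapes hypothetical:

import torch
import torch.nn.functional as F

C, G, N = 3, 40, 5
codebook = torch.randn(1, C, G, G)  # [1, C, num_grid, num_grid]
x = torch.rand(N, 2) * G            # query points in grid units

# Map pixel centers to [-1, 1], the convention used by align_corners=False.
grid = ((x + 0.5) / G * 2 - 1).view(1, 1, N, 2)
v_x = F.grid_sample(codebook, grid, align_corners=False)  # [1, C, 1, N]
v_x = torch.squeeze(torch.squeeze(v_x, 0), 1).transpose(0, 1)  # [N, C]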
@@ -15,7 +15,6 @@

"""Grid score calculations."""

-
import math

import matplotlib.pyplot as plt