make scheduling stateless wrt sampler
DruvPai committed Feb 4, 2025
1 parent fbb1d1d commit 90b6178
Showing 15 changed files with 234 additions and 189 deletions.
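
In short: before this commit the time schedule was baked into the sampler; after it, the sampler is stateless and the schedule is built explicitly and threaded through `sample`. A schematic before/after, using names from the hunks below (not self-contained: `vector_field`, `X0`, and `Zs` are as in demo.py, and the exact `get_ts`/`sample` signatures are inferred from the diff):

# v1.1.1: the sampler owned the schedule hyperparameters.
sampler = FMSampler(is_stochastic=False, t_min=0.01, t_max=0.99, L=100)
X = sampler.sample(vector_field, X0, Zs)

# v1.2.0: schedule hyperparameters live with the caller; the sampler is stateless.
sampler = FMSampler(is_stochastic=False)
ts = sampler.get_ts({"t_min": 0.01, "t_max": 0.99, "L": 100})
X = sampler.sample(vector_field, X0, Zs, ts)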
11 changes: 7 additions & 4 deletions demo.py
@@ -1,10 +1,11 @@
+from typing import cast
 import lightning
 import torch
 from torch import nn
 from torch.utils.data import DataLoader, TensorDataset
 from diffusionlab.distributions.gmm import IsoHomoGMMDistribution
 from diffusionlab.model import DiffusionModel
-from diffusionlab.samplers import FMSampler
+from diffusionlab.sampler import FMSampler
 from diffusionlab.vector_fields import VectorField, VectorFieldType
 
 lightning.seed_everything(42)

Check failure (GitHub Actions / test): Ruff F401 at demo.py:1:20: `typing.cast` imported but unused.
@@ -36,7 +37,8 @@ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
 t_min = 0.01
 t_max = 0.99
 L = 100
-sampler = FMSampler(is_stochastic=False, t_min=t_min, t_max=t_max, L=L)
+train_ts_hparams = {"t_min": t_min, "t_max": t_max, "L": L}
+sampler = FMSampler(is_stochastic=False)
 
 means = torch.randn(K, D) * 3
 var = torch.tensor(0.5)
@@ -67,9 +69,10 @@ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
     sampler=sampler,
     vector_field_type=VectorFieldType.EPS,
     optimizer=optimizer,
-    scheduler=scheduler,
+    lr_scheduler=scheduler,
     batchwise_val_metrics={},
     overall_val_metrics={},
+    train_ts_hparams=train_ts_hparams,
     t_loss_weights=lambda t: torch.ones_like(t),
     t_loss_probs=lambda t: torch.ones_like(t) / L,
     N_noise_per_sample=10,
@@ -91,7 +94,7 @@ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
 N_sample = 20
 X0 = torch.randn(N_sample, D)
 Zs = torch.randn(L - 1, N_sample, D, device=X0.device)
-X_sample = sampler.sample(sampling_vector_field, X0, Zs)
+X_sample = sampler.sample(sampling_vector_field, X0, Zs, model.train_ts)
 
 distances = torch.cdist(X_sample, X_train).min(dim=1)
 print(distances)
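
One payoff of passing the schedule explicitly: inference no longer has to reuse the training grid. A hedged sketch continuing demo.py (assuming `get_ts` accepts the same hparams dict and `sample` accepts any schedule whose length matches `Zs.shape[0] + 1`):

# Hypothetical: sample on a coarser 25-step grid with the same trained model.
eval_ts = sampler.get_ts({"t_min": t_min, "t_max": t_max, "L": 25})
Zs_eval = torch.randn(25 - 1, N_sample, D, device=X0.device)
X_fast = sampler.sample(sampling_vector_field, X0, Zs_eval, eval_ts)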
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "diffusionlab"
-version = "1.1.1"
+version = "1.2.0"
 description = "Easy no-frills Pytorch implementations of common abstractions for diffusion models."
 readme = "README.md"
 requires-python = ">=3.12"
2 changes: 1 addition & 1 deletion src/diffusionlab/distributions/base.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from diffusionlab.samplers import Sampler
+from diffusionlab.sampler import Sampler
 from diffusionlab.vector_fields import VectorFieldType, convert_vector_field_type
2 changes: 1 addition & 1 deletion src/diffusionlab/distributions/empirical.py
@@ -4,7 +4,7 @@
 from torch.utils.data import DataLoader
 
 from diffusionlab.distributions.base import Distribution
-from diffusionlab.samplers import Sampler
+from diffusionlab.sampler import Sampler
 from diffusionlab.utils import pad_shape_back
2 changes: 1 addition & 1 deletion src/diffusionlab/distributions/gmm.py
@@ -3,7 +3,7 @@
 import torch
 
 from diffusionlab.distributions.base import Distribution
-from diffusionlab.samplers import Sampler
+from diffusionlab.sampler import Sampler
 from diffusionlab.utils import logdet_pd, sqrt_psd, vector_lstsq
2 changes: 1 addition & 1 deletion src/diffusionlab/loss.py
@@ -3,7 +3,7 @@
 import torch
 from torch import nn
 
-from diffusionlab.samplers import Sampler
+from diffusionlab.sampler import Sampler
 from diffusionlab.utils import pad_shape_back
 from diffusionlab.vector_fields import VectorFieldType
67 changes: 31 additions & 36 deletions src/diffusionlab/model.py
@@ -4,7 +4,7 @@
 from lightning import LightningModule
 
 from diffusionlab.loss import SamplewiseDiffusionLoss
-from diffusionlab.samplers import Sampler
+from diffusionlab.sampler import Sampler
 from diffusionlab.vector_fields import VectorField, VectorFieldType

@@ -15,9 +15,10 @@ def __init__(
         sampler: Sampler,
         vector_field_type: VectorFieldType,
         optimizer: optim.Optimizer,
-        scheduler: optim.lr_scheduler.LRScheduler,
+        lr_scheduler: optim.lr_scheduler.LRScheduler,
         batchwise_val_metrics: Dict[str, nn.Module],
         overall_val_metrics: Dict[str, nn.Module],
+        train_ts_hparams: Dict[str, float],
         t_loss_weights: Callable[[torch.Tensor], torch.Tensor],
         t_loss_probs: Callable[[torch.Tensor], torch.Tensor],
         N_noise_per_sample: int,
@@ -27,24 +28,28 @@
         self.vector_field_type: VectorFieldType = vector_field_type
         self.sampler: Sampler = sampler
         self.optimizer: optim.Optimizer = optimizer
-        self.scheduler: optim.lr_scheduler.LRScheduler = scheduler
-        self.batchwise_val_metrics: Dict[str, nn.Module] = batchwise_val_metrics
-        self.overall_val_metrics: Dict[str, nn.Module] = overall_val_metrics
+        self.lr_scheduler: optim.lr_scheduler.LRScheduler = lr_scheduler
+        self.batchwise_val_metrics: nn.ModuleDict = nn.ModuleDict(batchwise_val_metrics)
+        self.overall_val_metrics: nn.ModuleDict = nn.ModuleDict(overall_val_metrics)
 
         self.t_loss_weights: Callable[[torch.Tensor], torch.Tensor] = t_loss_weights
         self.t_loss_probs: Callable[[torch.Tensor], torch.Tensor] = t_loss_probs
         self.N_noise_per_sample: int = N_noise_per_sample
 
-        self.t_loss_weights_precomputed: torch.Tensor = self.t_loss_weights(
-            self.sampler.schedule
-        )
-        self.t_loss_probs_precomputed: torch.Tensor = self.t_loss_probs(
-            self.sampler.schedule
-        )
         self.samplewise_loss: SamplewiseDiffusionLoss = SamplewiseDiffusionLoss(
             sampler, vector_field_type
         )
 
+        self.precompute_train_schedule(train_ts_hparams)
+
+    def precompute_train_schedule(self, train_ts_hparams: Dict[str, float]) -> None:
+        train_ts = self.sampler.get_ts(train_ts_hparams).to(self.device, non_blocking=True)
+        train_ts_loss_weights: torch.Tensor = self.t_loss_weights(train_ts)
+        train_ts_loss_probs: torch.Tensor = self.t_loss_probs(train_ts)
+        self.register_buffer("train_ts", train_ts)
+        self.register_buffer("train_ts_loss_weights", train_ts_loss_weights)
+        self.register_buffer("train_ts_loss_probs", train_ts_loss_probs)
 
     def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
         return self.net(x, t)
@@ -54,16 +59,12 @@ def configure_optimizers(
     ) -> Dict[
         Literal["optimizer", "lr_scheduler"],
         optim.Optimizer | optim.lr_scheduler.LRScheduler,
     ]:
-        return {"optimizer": self.optimizer, "lr_scheduler": self.scheduler}
+        return {"optimizer": self.optimizer, "lr_scheduler": self.lr_scheduler}
 
-    def loss(
-        self, x: torch.Tensor, t: torch.Tensor, sample_weights: torch.Tensor
-    ) -> torch.Tensor:
+    def loss(self, x: torch.Tensor, t: torch.Tensor, sample_weights: torch.Tensor) -> torch.Tensor:
         x = torch.repeat_interleave(x, self.N_noise_per_sample, dim=0)
         t = torch.repeat_interleave(t, self.N_noise_per_sample, dim=0)
-        sample_weights = torch.repeat_interleave(
-            sample_weights, self.N_noise_per_sample, dim=0
-        )
+        sample_weights = torch.repeat_interleave(sample_weights, self.N_noise_per_sample, dim=0)
 
         eps = torch.randn_like(x)
         xt = self.sampler.add_noise(x, t, eps)
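
For reference, `repeat_interleave` along dim=0 pairs each data point with `N_noise_per_sample` independent noise draws by duplicating rows in place; a minimal runnable illustration:

import torch

x = torch.tensor([[1.0], [2.0]])
print(torch.repeat_interleave(x, 3, dim=0).squeeze(-1))
# tensor([1., 1., 1., 2., 2., 2.])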
@@ -74,13 +75,11 @@
         return mean_loss
 
     def aggregate_loss(self, x: torch.Tensor) -> torch.Tensor:
-        t_idx = torch.multinomial(
-            self.t_loss_probs_precomputed, x.shape[0], replacement=True
-        ).to(x.device, non_blocking=True)
-        t = self.sampler.schedule.to(x.device, non_blocking=True)[t_idx]
-        t_weights = self.t_loss_weights_precomputed.to(x.device, non_blocking=True)[
-            t_idx
-        ]
+        t_idx = torch.multinomial(self.train_ts_loss_probs, x.shape[0], replacement=True).to(
+            self.device, non_blocking=True
+        )
+        t = self.train_ts[t_idx]
+        t_weights = self.train_ts_loss_weights[t_idx]
         mean_loss = self.loss(x, t, t_weights)
         return mean_loss

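The new `aggregate_loss` draws one training time per example by sampling indices from the precomputed probabilities and indexing the buffers; in miniature (toy values, not from the repo):

import torch

train_ts = torch.tensor([0.1, 0.5, 0.9])
train_ts_loss_probs = torch.ones(3) / 3
t_idx = torch.multinomial(train_ts_loss_probs, 4, replacement=True)
t = train_ts[t_idx]  # one timestep per sample, e.g. tensor([0.5, 0.1, 0.9, 0.5])
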
@@ -90,22 +89,18 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
         self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
         return loss
 
-    def validation_step(
-        self, batch: torch.Tensor, batch_idx: int
-    ) -> Dict[str, torch.Tensor]:
+    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> Dict[str, torch.Tensor]:
         x, metadata = batch
         loss = self.aggregate_loss(x)
-        metric_values = {
-            metric_name: metric(x, metadata, self)
-            for metric_name, metric in self.batchwise_val_metrics.items()
-        }
+        metric_values = {}
+        for metric_name, metric in self.batchwise_val_metrics.items():
+            metric_values[metric_name] = metric(x, metadata, self)
         metric_values["val_loss"] = loss
         self.log_dict(metric_values, on_step=True, on_epoch=True, prog_bar=True)
         return metric_values
 
     def on_validation_epoch_end(self) -> None:
-        metric_values = {
-            metric_name: metric(self)
-            for metric_name, metric in self.overall_val_metrics.items()
-        }
+        metric_values = {}
+        for metric_name, metric in self.overall_val_metrics.items():
+            metric_values[metric_name] = metric(self)
         self.log_dict(metric_values, on_step=False, on_epoch=True, prog_bar=True)
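
Why `register_buffer` rather than plain attributes: buffers are saved in the state_dict and moved by `.to()`/`.cuda()` along with parameters, which is what lets `aggregate_loss` index `self.train_ts` directly instead of calling `.to(x.device)` on every step as the old code did. A toy sketch of that behavior (hypothetical module, not from this repo):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Buffers persist in state_dict and follow the module across devices.
        self.register_buffer("train_ts", torch.linspace(0.01, 0.99, 100))

toy = Toy()
print("train_ts" in toy.state_dict())  # True
print(toy.train_ts.shape)              # torch.Size([100])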
(Diff truncated: the remaining changed files were not loaded on this page.)