Skip to content

Commit

Permalink
Merge pull request #7 from JonathanCrabbe/other_backbone
Browse files Browse the repository at this point in the history
Other Backbone
  • Loading branch information
nicolashuynh authored Jan 31, 2024
2 parents 276ad71 + 3eb8913 commit 0be30a2
Show file tree
Hide file tree
Showing 27 changed files with 1,378 additions and 13,277 deletions.
10 changes: 10 additions & 0 deletions cmd/conf/score_model/lstm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
_target_: fdiff.models.score_models.LSTMScoreModule
_partial_: true
d_model: 72
num_layers: 10
lr_max: 1.0e-3
fourier_noise_scaling: ${fourier_transform}
likelihood_weighting: False

defaults:
- noise_scheduler: vpsde
11 changes: 11 additions & 0 deletions cmd/conf/score_model/mlp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_target_: fdiff.models.score_models.MLPScoreModule
_partial_: true
d_model: 72
d_mlp: 1024
num_layers: 10
lr_max: 1.0e-4
fourier_noise_scaling: ${fourier_transform}
likelihood_weighting: False

defaults:
- noise_scheduler: vpsde
2 changes: 0 additions & 2 deletions cmd/conf/score_model/noise_scheduler/customddpm.yaml

This file was deleted.

2 changes: 0 additions & 2 deletions cmd/conf/score_model/noise_scheduler/ddpm.yaml

This file was deleted.

5 changes: 3 additions & 2 deletions cmd/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from fdiff.models.score_models import ScoreModule
from fdiff.sampling.metrics import MetricCollection
from fdiff.sampling.sampler import DiffusionSampler
from fdiff.utils.extraction import dict_to_str, get_best_checkpoint
from fdiff.utils.extraction import dict_to_str, get_best_checkpoint, get_model_type
from fdiff.utils.fourier import idft


Expand Down Expand Up @@ -48,7 +48,8 @@ def __init__(self, cfg: DictConfig) -> None:

# Load score model from checkpoint
best_checkpoint_path = get_best_checkpoint(self.save_dir / "checkpoints")
self.score_model = ScoreModule.load_from_checkpoint(
model_type = get_model_type(train_cfg)
self.score_model = model_type.load_from_checkpoint(
checkpoint_path=best_checkpoint_path
)
if torch.cuda.is_available():
Expand Down
Loading

0 comments on commit 0be30a2

Please sign in to comment.