Commit

Make MLP module residual
JonathanCrabbe committed Jan 25, 2024
1 parent 49a7fa3 commit ee9c4c8
Showing 2 changed files with 13 additions and 7 deletions.
5 changes: 3 additions & 2 deletions cmd/conf/score_model/mlp.yaml
```diff
@@ -1,8 +1,9 @@
 _target_: fdiff.models.score_models.MLPScoreModule
 _partial_: true
-d_model: 512
+d_model: 128
+d_mlp: 1024
 num_layers: 10
-lr_max: 1.0e-3
+lr_max: 1.0e-4
 fourier_noise_scaling: ${fourier_transform}
 likelihood_weighting: False
```

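For context, `_target_` and `_partial_` are Hydra conventions: the config names the class to build and defers the remaining constructor arguments. A minimal sketch of how the updated values resolve, assuming the `fdiff` package is importable; the `${fourier_transform}` interpolation is replaced with a literal here because it normally resolves against the parent config:

```python
# Minimal sketch (not part of the commit): with `_partial_: true`, Hydra's
# instantiate() returns a functools.partial over MLPScoreModule with the
# config values bound.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "fdiff.models.score_models.MLPScoreModule",
        "_partial_": True,
        "d_model": 128,
        "d_mlp": 1024,
        "num_layers": 10,
        "lr_max": 1.0e-4,
        "fourier_noise_scaling": True,  # stands in for ${fourier_transform}
        "likelihood_weighting": False,
    }
)

score_model_factory = instantiate(cfg)  # functools.partial(MLPScoreModule, ...)
# Data-dependent arguments (n_channels, max_len, noise_scheduler, ...) are
# supplied when the factory is finally called, e.g. by the training script:
# score_model = score_model_factory(n_channels=..., max_len=..., noise_scheduler=...)
```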
15 changes: 10 additions & 5 deletions src/fdiff/models/score_models.py
```diff
@@ -190,7 +190,8 @@ def __init__(
         max_len: int,
         noise_scheduler: DDPMScheduler | SDE,
         fourier_noise_scaling: bool = True,
-        d_model: int = 60,
+        d_model: int = 72,
+        d_mlp: int = 512,
         num_layers: int = 3,
         num_training_steps: int = 1000,
         lr_max: float = 1e-3,
@@ -216,9 +217,12 @@ def __init__(
         self.unembedder = nn.Linear(
             in_features=d_model, out_features=max_len * n_channels
         )
-        self.backbone = MLP(
-            in_channels=d_model,
-            hidden_channels=[d_model] * num_layers,
+
+        self.backbone = nn.ModuleList(
+            [
+                MLP(in_channels=d_model, hidden_channels=[d_mlp, d_model], dropout=0.1)
+                for _ in range(num_layers)
+            ]
         )
         self.pos_encoder = None
 
@@ -245,7 +249,8 @@ def forward(self, batch: DiffusableBatch) -> torch.Tensor:
         X = self.time_encoder(X, timesteps, use_time_axis=False)
 
         # Backbone
-        X = self.backbone(X)
+        for layer in self.backbone:
+            X = X + layer(X)
 
         # Channel unembedding
         X = self.unembedder(X)
```
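This is the substance of the commit: the single deep `MLP` backbone becomes a stack of `num_layers` shallow MLP blocks, each wrapped in a skip connection. A standalone sketch of the new backbone, assuming `MLP` is `torchvision.ops.MLP` (the import sits outside this diff); each block maps `d_model -> d_mlp -> d_model`, so the residual sum `X + layer(X)` is shape-compatible:

```python
# Standalone sketch of the residual backbone introduced by this commit,
# assuming MLP is torchvision.ops.MLP; names other than `backbone` are
# illustrative, not taken from score_models.py.
import torch
import torch.nn as nn
from torchvision.ops import MLP


class ResidualMLPBackbone(nn.Module):
    def __init__(self, d_model: int = 128, d_mlp: int = 1024, num_layers: int = 10):
        super().__init__()
        # Each block projects d_model -> d_mlp -> d_model, matching
        # MLP(in_channels=d_model, hidden_channels=[d_mlp, d_model], dropout=0.1).
        self.backbone = nn.ModuleList(
            [
                MLP(in_channels=d_model, hidden_channels=[d_mlp, d_model], dropout=0.1)
                for _ in range(num_layers)
            ]
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        for layer in self.backbone:
            X = X + layer(X)  # skip connection: each block learns a correction
        return X


if __name__ == "__main__":
    backbone = ResidualMLPBackbone()
    x = torch.randn(4, 50, 128)  # (batch, tokens, d_model); Linear acts on last dim
    print(backbone(x).shape)  # torch.Size([4, 50, 128])
```

Because each block now only learns a perturbation of the identity, deeper stacks tend to train more stably, which fits the accompanying config changes that shrink `d_model` (512 to 128) and `lr_max` (1e-3 to 1e-4).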
