add variable loss function
namsaraeva committed May 29, 2024
1 parent b6aa14f commit 96921d5
Showing 1 changed file with 19 additions and 3 deletions.
src/sparcscore/ml/plmodels.py (22 changes: 19 additions & 3 deletions)
@@ -168,7 +168,7 @@ def test_step(self, batch, batch_idx):

class RegressionModel(pl.LightningModule):

    def __init__(self, model_type="VGG2_regression", **kwargs):
        super().__init__()
        self.save_hyperparameters()

@@ -198,15 +198,31 @@ def configure_optimizers(self):
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams["learning_rate"], weight_decay=self.hparams["weight_decay"])

        else:
-            raise ValueError("No optimizer specified in hparams")
+            raise ValueError("No optimizer specified in hparams.")

        return optimizer

+    def configure_loss(self):
+        if self.hparams["loss"] == "mse":
+            loss = F.mse_loss
+        elif self.hparams["loss"] == "huber":
+            if self.hparams["huber_delta"] is None:
+                self.hparams["huber_delta"] = 1.0
+            loss = F.huber_loss
+        else:
+            raise ValueError("No loss function specified in hparams.")
+
+        return loss

    def training_step(self, batch):
        data, target = batch
        target = target.unsqueeze(1)
        output = self.network(data)  # Forward pass, only one output
-        loss = F.huber_loss(output, target, delta=1.0, reduction='mean')  # consider looking at parameters again

+        loss = self.configure_loss()  # select the loss function configured in hparams
+
+        if self.hparams["loss"] == "huber":  # Huber loss needs its delta
+            loss = loss(output, target, delta=self.hparams["huber_delta"], reduction='mean')
+        else:  # MSE
+            loss = loss(output, target)

        self.log('loss/train', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('mse/train', self.mse(output, target), on_epoch=True, prog_bar=True)
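For context, a minimal standalone sketch of the loss selection this commit introduces. The hparams keys (`loss`, `huber_delta`) and the functional calls (`F.mse_loss`, `F.huber_loss`) mirror the diff above; the plain dict, the dummy tensors, and running the logic outside a Lightning `Trainer` are illustrative assumptions, not part of the repository.

```python
import torch
import torch.nn.functional as F

# Hypothetical hparams, mirroring the keys used by RegressionModel above.
hparams = {"loss": "huber", "huber_delta": 1.0}

# Same selection logic as configure_loss(): pick the functional loss by name.
if hparams["loss"] == "mse":
    loss_fn = F.mse_loss
elif hparams["loss"] == "huber":
    if hparams["huber_delta"] is None:
        hparams["huber_delta"] = 1.0
    loss_fn = F.huber_loss
else:
    raise ValueError("No loss function specified in hparams.")

# Dummy regression batch: predictions and targets of shape (batch, 1).
output = torch.randn(8, 1)
target = torch.randn(8, 1)

# Same branching as training_step(): only Huber takes the delta keyword.
if hparams["loss"] == "huber":
    loss = loss_fn(output, target, delta=hparams["huber_delta"], reduction="mean")
else:
    loss = loss_fn(output, target)

print(f"{hparams['loss']} loss: {loss.item():.4f}")
```

Because `F.huber_loss` accepts a `delta` keyword that `F.mse_loss` does not, the branch on `hparams["loss"]` is repeated at call time; an alternative would be to bind `delta` with `functools.partial` inside `configure_loss` so the returned callable can be invoked uniformly.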
