From 4a4c03d54c565aac934c798af5d1e37f48eb3f40 Mon Sep 17 00:00:00 2001 From: namsaraeva Date: Fri, 24 May 2024 11:10:36 +0200 Subject: [PATCH] unsqueeze --- src/sparcscore/ml/plmodels.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/sparcscore/ml/plmodels.py b/src/sparcscore/ml/plmodels.py index 8bb2d15..ced297f 100644 --- a/src/sparcscore/ml/plmodels.py +++ b/src/sparcscore/ml/plmodels.py @@ -204,40 +204,43 @@ def configure_optimizers(self): def training_step(self, batch): data, target = batch + target = target.unsqueeze(1) output = self.network(data) # Forward pass, only one output - loss = F.mse_loss(output, target.unsqueeze(1)) # L2 loss + loss = F.mse_loss(output, target) # L2 loss # accuracy metrics for regression??? self.log('loss/train', loss, on_step=False, on_epoch=True, prog_bar=True) - self.log('mse/train', self.mse(output, target.unsqueeze(1)), on_step=False, on_epoch=True, prog_bar=True) - self.log('mae/train', self.mae(output, target.unsqueeze(1)), on_step=False, on_epoch=True, prog_bar=True) + self.log('mse/train', self.mse(output, target), on_step=False, on_epoch=True, prog_bar=True) + self.log('mae/train', self.mae(output, target), on_step=False, on_epoch=True, prog_bar=True) return {'loss': loss, 'predictions': output, 'targets': target} def validation_step(self, batch): data, target = batch + target = target.unsqueeze(1) output = self.network(data) - loss = F.mse_loss(output, target.unsqueeze(1)) + loss = F.mse_loss(output, target) # accuracy metrics for regression??? self.log('loss/val', loss, on_step=False, on_epoch=True, prog_bar=True) - self.log('mse/val', self.mse(output, target.unsqueeze(1)), on_step=False, on_epoch=True, prog_bar=True) - self.log('mae/val', self.mae(output, target.unsqueeze(1)), on_step=False, on_epoch=True, prog_bar=True) + self.log('mse/val', self.mse(output, target), on_step=False, on_epoch=True, prog_bar=True) + self.log('mae/val', self.mae(output, target), on_step=False, on_epoch=True, prog_bar=True) return {'loss': loss, 'predictions': output, 'targets': target} def test_step(self, batch): data, target = batch + target = target.unsqueeze(1) output = self.network(data) - loss = F.mse_loss(output, target.unsqueeze(1)) + loss = F.mse_loss(output, target) # accuracy metrics for regression??? self.log('loss/test', loss, on_step=False, on_epoch=True, prog_bar=True) - self.log('mse/test', self.mse(output, target.unsqueeze(1)), on_step=False, on_epoch=True, prog_bar=True) - self.log('mae/test', self.mae(output, target.unsqueeze(1)), on_step=False, on_epoch=True, prog_bar=True) + self.log('mse/test', self.mse(output, target), on_step=False, on_epoch=True, prog_bar=True) + self.log('mae/test', self.mae(output, target), on_step=False, on_epoch=True, prog_bar=True) return {'loss': loss, 'predictions': output, 'targets': target}