diff --git a/src/sparcscore/ml/plmodels.py b/src/sparcscore/ml/plmodels.py index 6ff8fbf..fa124e0 100644 --- a/src/sparcscore/ml/plmodels.py +++ b/src/sparcscore/ml/plmodels.py @@ -205,6 +205,7 @@ def configure_optimizers(self): def training_step(self, batch): data, target = batch print("Training data shape: ", data.shape) + print("Batch size: ", data.size(0)) output = self.network(data) # Forward pass, only one output loss = F.mse_loss(output, target) # L2 loss