diff --git a/src/sparcscore/ml/models.py b/src/sparcscore/ml/models.py index 8a046e4..51f10a0 100644 --- a/src/sparcscore/ml/models.py +++ b/src/sparcscore/ml/models.py @@ -191,7 +191,7 @@ def forward(self, x): x = self.features(x) print("x.shape after features", x.shape) - #x = torch.flatten(x, 1) + x = torch.flatten(x, 1) print("x.shape after flatten", x.shape) x = self.classifier(x)