diff --git a/src/sparcscore/ml/models.py b/src/sparcscore/ml/models.py index 51f10a0..8a046e4 100644 --- a/src/sparcscore/ml/models.py +++ b/src/sparcscore/ml/models.py @@ -191,7 +191,7 @@ def forward(self, x): x = self.features(x) print("x.shape after features", x.shape) - x = torch.flatten(x, 1) + #x = torch.flatten(x, 1) print("x.shape after flatten", x.shape) x = self.classifier(x)