diff --git a/src/sparcscore/ml/plmodels.py b/src/sparcscore/ml/plmodels.py index 6c97da4..786f818 100644 --- a/src/sparcscore/ml/plmodels.py +++ b/src/sparcscore/ml/plmodels.py @@ -168,7 +168,7 @@ def test_step(self, batch, batch_idx): class RegressionModel(pl.LightningModule): - def __init__(self, model_type="VGG2_regression" **kwargs): + def __init__(self, model_type="VGG2_regression", **kwargs): super().__init__() self.save_hyperparameters()