diff --git a/ludwig/modules/metric_modules.py b/ludwig/modules/metric_modules.py index 9d6e26335f2..0281f967a21 100644 --- a/ludwig/modules/metric_modules.py +++ b/ludwig/modules/metric_modules.py @@ -136,7 +136,7 @@ class RMSEMetric(MeanSquaredError, LudwigMetric): """Root mean squared error metric.""" def __init__(self, **kwargs): - super().__init__(squared=False, **kwargs) + super().__init__(squared=False) @register_metric(PRECISION, [BINARY], MAXIMIZE, PROBABILITIES) diff --git a/requirements.txt b/requirements.txt index 0c3deeebf5e..1f997e72c4b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ marshmallow-jsonschema marshmallow-dataclass==8.5.4 tensorboard nltk # Required for rouge scores. -torchmetrics>=0.11.0,<=0.11.4 +torchmetrics>=0.11.0 torchinfo filelock psutil==5.9.4