diff --git a/test/evaluation/harness/rag/test_harness.py b/test/evaluation/harness/rag/test_harness.py index 65e528c2..2918420a 100644 --- a/test/evaluation/harness/rag/test_harness.py +++ b/test/evaluation/harness/rag/test_harness.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import pytest from haystack_experimental.evaluation.harness.rag import ( @@ -106,7 +106,9 @@ def run(self, query: str) -> Dict[str, Any]: @component class MockEvaluator: - def __init__(self, metric: RAGEvaluationMetric) -> None: + def __init__(self, metric: Union[str, RAGEvaluationMetric]) -> None: + if isinstance(metric, str): + metric = RAGEvaluationMetric(metric) self.metric = metric io_map = { @@ -132,6 +134,9 @@ def __init__(self, metric: RAGEvaluationMetric) -> None: self.__haystack_input__ = io_map[metric].__haystack_input__ self.__haystack_output__ = io_map[metric].__haystack_output__ + def to_dict(self): + return default_to_dict(self, metric=str(self.metric)) + @staticmethod def default_output(metric) -> Dict[str, Any]: if metric in (