From 21dae65bb5a18a5ba1dd8c8bd4c4f5a4c4825214 Mon Sep 17 00:00:00 2001
From: Samriddhi Sinha
Date: Tue, 9 Jul 2024 06:44:18 +0530
Subject: [PATCH] Added Verbosity as an argument

---
 README.md                             |  1 -
 groqeval/__init__.py                  |  4 ++-
 groqeval/metrics/answer_relevance.py  |  7 +++---
 groqeval/metrics/base_metric.py       | 35 ++++++++++++++++-----------
 groqeval/metrics/bias.py              |  7 +++---
 groqeval/metrics/context_relevance.py |  7 +++---
 groqeval/metrics/faithfulness.py      |  8 +++---
 groqeval/metrics/hallucination.py     |  7 +++---
 groqeval/metrics/toxicity.py          |  6 ++---
 pyproject.toml                        |  3 ++-
 requirements.txt                      |  3 ++-
 tests/test_evaluate.py                | 17 ++++++++++---
 12 files changed, 65 insertions(+), 40 deletions(-)

diff --git a/README.md b/README.md
index 8fb4858..76f5f7c 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,6 @@ evaluator = GroqEval(api_key=API_KEY)
 The evaluator is the central orchestrator that initializes the metrics.
 
 ```python
-from groqeval.evaluate import GroqEval
 metrics = evaluator(metric_name, **kwargs)
 ```
 
diff --git a/groqeval/__init__.py b/groqeval/__init__.py
index ce7c33d..63be258 100644
--- a/groqeval/__init__.py
+++ b/groqeval/__init__.py
@@ -1 +1,3 @@
-from groqeval.evaluate import GroqEval
\ No newline at end of file
+from groqeval.evaluate import GroqEval
+
+__all__ = ["GroqEval"]
diff --git a/groqeval/metrics/answer_relevance.py b/groqeval/metrics/answer_relevance.py
index a4ef493..fcb46c7 100644
--- a/groqeval/metrics/answer_relevance.py
+++ b/groqeval/metrics/answer_relevance.py
@@ -12,8 +12,8 @@ class AnswerRelevance(BaseMetric):
     relevance to the original question, helping to gauge the utility and appropriateness
     of the model's responses.
     """
-    def __init__(self, groq_client: Groq, output: str, prompt: str):
-        super().__init__(groq_client)
+    def __init__(self, groq_client: Groq, output: str, prompt: str, **kwargs):
+        super().__init__(groq_client, kwargs.get('verbose'))
         self.output = output
         self.prompt = prompt
         self.check_data_types(prompt=prompt, output=output)
@@ -66,13 +66,13 @@ def output_decomposition(self):
             {"role": "system", "content": self.output_decomposition_prompt},
             {"role": "user", "content": self.output}
         ]
-        print(messages)
         response = self.groq_chat_completion(
             messages=messages,
             model="llama3-70b-8192",
             temperature=0,
             response_format={"type": "json_object"}
         )
+        self.logger.info("Decomposition of the Output into Statements: %s", response.choices[0].message.content)
         return Output.model_validate_json(response.choices[0].message.content)
 
     def score_relevance(self):
@@ -93,6 +93,7 @@ def score_relevance(self):
             temperature=0,
             response_format={"type": "json_object"}
         )
+        self.logger.info("Breakdown of the Answer Relevance Score: %s", response.choices[0].message.content)
         return ScoredOutput.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)
 
     def score(self):
diff --git a/groqeval/metrics/base_metric.py b/groqeval/metrics/base_metric.py
index d4caac2..5cd31ee 100644
--- a/groqeval/metrics/base_metric.py
+++ b/groqeval/metrics/base_metric.py
@@ -1,9 +1,16 @@
+import logging
+from groq import Groq
+
 class BaseMetric:
     """
     The Base Metric class.
""" - def __init__(self, groq_client): + def __init__(self, groq_client: Groq, verbose: bool = None): self.groq_client = groq_client + self.logger = logging.getLogger(__name__) + if verbose: + self.logger.setLevel(logging.INFO) + def groq_chat_completion(self, messages, model, temperature=0.5, response_format=None): """ @@ -15,7 +22,6 @@ def groq_chat_completion(self, messages, model, temperature=0.5, response_format temperature=temperature, response_format=response_format ) - print(chat_completion.choices[0].message.content) return chat_completion def check_data_types(self, **kwargs): @@ -23,19 +29,20 @@ def check_data_types(self, **kwargs): Checks for empty strings in the arguments """ for key, value in kwargs.items(): - if key != "context": - if value == "": - raise ValueError(f"'{key}' cannot be an empty string.") - if not isinstance(value, str): - raise TypeError(f"'{key}' must be a string") - else: - if len(value) == 0: - raise ValueError(f"'{key}' cannot be an empty list.") - if not isinstance(value, list): - raise TypeError(f"'{key}' must be a list of strings") + if key != "verbose": + if key != "context": + if value == "": + raise ValueError(f"'{key}' cannot be an empty string.") + if not isinstance(value, str): + raise TypeError(f"'{key}' must be a string") else: - if not all(isinstance(item, str) for item in value): - raise TypeError(f"All items in '{key}' must be strings") + if len(value) == 0: + raise ValueError(f"'{key}' cannot be an empty list.") + if not isinstance(value, list): + raise TypeError(f"'{key}' must be a list of strings") + else: + if not all(isinstance(item, str) for item in value): + raise TypeError(f"All items in '{key}' must be strings") diff --git a/groqeval/metrics/bias.py b/groqeval/metrics/bias.py index eab57a3..3ef34ad 100644 --- a/groqeval/metrics/bias.py +++ b/groqeval/metrics/bias.py @@ -12,8 +12,8 @@ class Bias(BaseMetric): context-driven expressions. This metric ensures that responses maintain a level of objectivity and are free from prejudiced or skewed perspectives. 
""" - def __init__(self, groq_client: Groq, output: str, prompt: str): - super().__init__(groq_client) + def __init__(self, groq_client: Groq, output: str, prompt: str, **kwargs): + super().__init__(groq_client, kwargs.get('verbose')) self.output = output self.prompt = prompt self.check_data_types(prompt=prompt, output=output) @@ -70,13 +70,13 @@ def output_decomposition(self): {"role": "system", "content": self.output_decomposition_prompt}, {"role": "user", "content": self.output} ] - print(messages) response = self.groq_chat_completion( messages=messages, model="llama3-70b-8192", temperature=0, response_format={"type": "json_object"} ) + self.logger.info("Decomposition of the Output into Opinions: %s", response.choices[0].message.content) return Output.model_validate_json(response.choices[0].message.content) def score_bias(self): @@ -97,6 +97,7 @@ def score_bias(self): temperature=0, response_format={"type": "json_object"} ) + self.logger.info("Breakdown of the Bias Score: %s", response.choices[0].message.content) return ScoredOutput.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content) def score(self): diff --git a/groqeval/metrics/context_relevance.py b/groqeval/metrics/context_relevance.py index d481667..aa12d09 100644 --- a/groqeval/metrics/context_relevance.py +++ b/groqeval/metrics/context_relevance.py @@ -13,8 +13,8 @@ class ContextRelevance(BaseMetric): to the generator is pertinent and likely to enhance the quality and accuracy of the generated responses. """ - def __init__(self, groq_client: Groq, context: List[str], prompt: str): - super().__init__(groq_client) + def __init__(self, groq_client: Groq, context: List[str], prompt: str, **kwargs): + super().__init__(groq_client, kwargs.get('verbose')) self.context = context self.prompt = prompt self.check_data_types(prompt=prompt, context=context) @@ -79,13 +79,13 @@ def context_decomposition(self): {"role": "system", "content": self.context_decomposition_prompt}, {"role": "user", "content": self.format_retrieved_context} ] - print(messages) response = self.groq_chat_completion( messages=messages, model="llama3-70b-8192", temperature=0, response_format={"type": "json_object"} ) + self.logger.info("Decomposition of the Context into Statements: %s", response.choices[0].message.content) return Context.model_validate_json(response.choices[0].message.content) def score_relevance(self): @@ -110,6 +110,7 @@ def score_relevance(self): temperature=0, response_format={"type": "json_object"} ) + self.logger.info("Breakdown of the Context Relevance Score: %s", response.choices[0].message.content) return ScoredContext.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content) def score(self): diff --git a/groqeval/metrics/faithfulness.py b/groqeval/metrics/faithfulness.py index a85ecef..e159f53 100644 --- a/groqeval/metrics/faithfulness.py +++ b/groqeval/metrics/faithfulness.py @@ -12,8 +12,8 @@ class Faithfulness(BaseMetric): content is not only relevant but also accurate and truthful with respect to the given context, critical for maintaining the integrity and reliability of the model's responses. 
""" - def __init__(self, groq_client: Groq, context: List[str], output: str): - super().__init__(groq_client) + def __init__(self, groq_client: Groq, context: List[str], output: str, **kwargs): + super().__init__(groq_client, kwargs.get('verbose')) self.context = context self.output = output self.check_data_types(context=context, output=output) @@ -80,13 +80,13 @@ def output_decomposition(self): {"role": "system", "content": self.output_decomposition_prompt}, {"role": "user", "content": self.output} ] - print(messages) response = self.groq_chat_completion( messages=messages, model="llama3-70b-8192", temperature=0, response_format={"type": "json_object"} ) + self.logger.info("Decomposition of the Output into Claims: %s", response.choices[0].message.content) return Output.model_validate_json(response.choices[0].message.content) def score_faithfulness(self): @@ -106,13 +106,13 @@ def score_faithfulness(self): {"role": "system", "content": self.faithfulness_prompt}, {"role": "user", "content": json.dumps({"sentences": [s.string for s in coherent_sentences]}, indent=2)} ] - print(messages) response = self.groq_chat_completion( messages=messages, model="llama3-70b-8192", temperature=0, response_format={"type": "json_object"} ) + self.logger.info("Breakdown of the Faithfulness Score: %s", response.choices[0].message.content) return ScoredOutput.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content) def score(self): diff --git a/groqeval/metrics/hallucination.py b/groqeval/metrics/hallucination.py index 305ba83..c1dc959 100644 --- a/groqeval/metrics/hallucination.py +++ b/groqeval/metrics/hallucination.py @@ -13,8 +13,8 @@ class Hallucination(BaseMetric): This is crucial for ensuring that the generated outputs remain grounded in the provided context and do not mislead or introduce inaccuracies. """ - def __init__(self, groq_client: Groq, context: List[str], output: str): - super().__init__(groq_client) + def __init__(self, groq_client: Groq, context: List[str], output: str, **kwargs): + super().__init__(groq_client, kwargs.get('verbose')) self.context = context self.output = output self.check_data_types(context=context, output=output) @@ -89,13 +89,13 @@ def context_decomposition(self): {"role": "system", "content": self.context_decomposition_prompt}, {"role": "user", "content": self.format_retrieved_context} ] - print(messages) response = self.groq_chat_completion( messages=messages, model="llama3-70b-8192", temperature=0, response_format={"type": "json_object"} ) + self.logger.info("Decomposition of the Context into Statements: %s", response.choices[0].message.content) return Context.model_validate_json(response.choices[0].message.content) def score_hallucination(self): @@ -116,6 +116,7 @@ def score_hallucination(self): temperature=0, response_format={"type": "json_object"} ) + self.logger.info("Breakdown of the Hallucination Score: %s", response.choices[0].message.content) return ScoredContext.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content) def score(self): diff --git a/groqeval/metrics/toxicity.py b/groqeval/metrics/toxicity.py index e61dc37..35bf13c 100644 --- a/groqeval/metrics/toxicity.py +++ b/groqeval/metrics/toxicity.py @@ -12,8 +12,8 @@ class Toxicity(BaseMetric): wider consumption, identifying any language that could be considered insulting, aggressive, or otherwise damaging. 
""" - def __init__(self, groq_client: Groq, output: str, prompt: str): - super().__init__(groq_client) + def __init__(self, groq_client: Groq, output: str, prompt: str, **kwargs): + super().__init__(groq_client, kwargs.get('verbose')) self.output = output self.prompt = prompt self.check_data_types(prompt=prompt, output=output) @@ -69,13 +69,13 @@ def output_decomposition(self): {"role": "system", "content": self.output_decomposition_prompt}, {"role": "user", "content": self.output} ] - print(messages) response = self.groq_chat_completion( messages=messages, model="llama3-70b-8192", temperature=0, response_format={"type": "json_object"} ) + self.logger.info("Breakdown of the Toxicity Score: %s", response.choices[0].message.content) return Output.model_validate_json(response.choices[0].message.content) def score_toxicity(self): diff --git a/pyproject.toml b/pyproject.toml index e593a16..cc53b18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,8 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ - "groq==0.9.0" + "groq>=0.9.0", + "pydantic>=2.7.4" ] [tool.twine] diff --git a/requirements.txt b/requirements.txt index 8ceb2aa..f70ce82 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -groq==0.9.0 \ No newline at end of file +groq==0.9.0 +pydantic==2.7.4 \ No newline at end of file diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index d0bd6ff..f3690cb 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -1,6 +1,6 @@ import os import importlib -from typing import List, Dict +import pytest from conftest import get_class_args, generate_random_value def metricize(file_name: str): @@ -26,5 +26,16 @@ def test_load_metrics(evaluator, metrics_folder, metrics_module): class_ = getattr(module, class_name) class_args = get_class_args(class_) random_args = {name: generate_random_value(param) for name, param in class_args.items()} - print(class_name, random_args) - assert type(evaluator(module_name, **random_args)) == class_ \ No newline at end of file + assert type(evaluator(module_name, **random_args)) == class_ + +def test_load_base_metric(evaluator, metrics_module): + module_name = "base_metric" + module_path = f'{metrics_module}.{"base_metric"}' + module = importlib.import_module(module_path) + class_name = metricize(module_name) + + class_ = getattr(module, class_name) + class_args = get_class_args(class_) + random_args = {name: generate_random_value(param) for name, param in class_args.items()} + with pytest.raises(TypeError, match=f"{class_name} is not a valid metric class"): + base_metric = evaluator(module_name, **random_args) \ No newline at end of file