Commit

Added Argument for Aggregation
djokester committed Jul 10, 2024
1 parent 21dae65 commit 6cbd4f8
Showing 15 changed files with 124 additions and 92 deletions.
21 changes: 5 additions & 16 deletions groqeval/metrics/answer_relevance.py
@@ -1,6 +1,7 @@
# groqeval/metrics/answer_relevance.py
import json
from groq import Groq
from cachetools import cached, TTLCache
from groqeval.models.output import Output, ScoredOutput
from groqeval.metrics.base_metric import BaseMetric

@@ -75,6 +76,7 @@ def output_decomposition(self):
self.logger.info("Decomposition of the Output into Statements: %s", response.choices[0].message.content)
return Output.model_validate_json(response.choices[0].message.content)

@cached(cache=TTLCache(maxsize=100, ttl=300))
def score_relevance(self):
"""
Each identified statement is then scored on a scale from 1 (completely irrelevant)
@@ -96,19 +98,6 @@ def score_relevance(self):
self.logger.info("Breakdown of the Answer Relevance Score: %s", response.choices[0].message.content)
return ScoredOutput.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)

def score(self):
"""
Aggregation of individual scores and final result.
"""
scored_output, output_dictionary = self.score_relevance()
if scored_output.scores:
average_score = sum([output.score for output in scored_output.scores]) / len(scored_output.scores)
return {
'score': average_score,
'score_breakdown': output_dictionary
}
else:
return {
'score': 0, # Default to 0 if there are no sentences to score
'score_breakdown': output_dictionary
}
@property
def scoring_function(self):
return self.score_relevance
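
With its own score() removed, AnswerRelevance now only supplies score_relevance and hands it to the base class through the scoring_function property; aggregation happens in BaseMetric.score() (next file). Callers keep the same entry point and gain an optional aggregation callable. A usage sketch, assuming the constructor takes the client, output and prompt the same way the Bias and Toxicity metrics below do, and that a valid Groq API key is supplied:

    from groq import Groq
    from groqeval.metrics.answer_relevance import AnswerRelevance

    client = Groq(api_key="...")  # placeholder key
    metric = AnswerRelevance(client, output="Meditation lowers stress and improves focus.",
                             prompt="What are the benefits of meditation?")
    print(metric.score())      # default: mean of the per-statement scores
    print(metric.score(max))   # same cached breakdown, reduced with max instead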
38 changes: 32 additions & 6 deletions groqeval/metrics/base_metric.py
@@ -1,17 +1,21 @@
import logging
import statistics
from abc import ABC,abstractmethod
from cachetools import cached, TTLCache
from groq import Groq

class BaseMetric:
class BaseMetric(ABC):
"""
The Base Metric class.
"""
def __init__(self, groq_client: Groq, verbose: bool = None):
self.groq_client = groq_client
self.aggregation = statistics.mean
self.logger = logging.getLogger(__name__)
if verbose:
self.logger.setLevel(logging.INFO)


@cached(cache=TTLCache(maxsize=100, ttl=300))
def groq_chat_completion(self, messages, model, temperature=0.5, response_format=None):
"""
Groq's chat completion API
@@ -43,8 +47,30 @@ def check_data_types(self, **kwargs):
else:
if not all(isinstance(item, str) for item in value):
raise TypeError(f"All items in '{key}' must be strings")



def score(self):

@property
@abstractmethod
def scoring_function(self):
"""
This property should be implemented by each child class
"""
raise NotImplementedError("This method should be overridden by subclasses")

def score(self, aggregation = None):
"""
Aggregation of individual scores and final result.
"""
if aggregation is not None:
self.aggregation = aggregation
scored_output, output_dictionary = self.scoring_function()
if scored_output.scores:
average_score = self.aggregation([output.score for output in scored_output.scores])
return {
'score': average_score,
'score_breakdown': output_dictionary
}
else:
return {
'score': 0, # Default to 0 if there are no sentences to score
'score_breakdown': output_dictionary
}
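
BaseMetric is now an ABC: score() looks up the child's bound scoring method through the abstract scoring_function property, runs it, and reduces the individual scores with self.aggregation (statistics.mean unless a subclass or the caller overrides it). A minimal standalone sketch of that flow, with the Groq call replaced by a stub and illustrative class names:

    import statistics
    from abc import ABC, abstractmethod

    class SketchMetric(ABC):
        def __init__(self):
            self.aggregation = statistics.mean        # default, as in BaseMetric

        @property
        @abstractmethod
        def scoring_function(self):
            raise NotImplementedError

        def score(self, aggregation=None):
            if aggregation is not None:
                self.aggregation = aggregation
            scores, breakdown = self.scoring_function()   # child supplies the scores
            if scores:
                return {'score': self.aggregation(scores), 'score_breakdown': breakdown}
            return {'score': 0, 'score_breakdown': breakdown}

    class SketchRelevance(SketchMetric):
        def stub_scoring(self):
            return [7, 9, 8], {"scores": [7, 9, 8]}       # stand-in for the LLM response

        @property
        def scoring_function(self):
            return self.stub_scoring

    metric = SketchRelevance()
    print(metric.score())      # mean of [7, 9, 8] -> 8
    print(metric.score(max))   # caller-supplied aggregation -> 9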
22 changes: 8 additions & 14 deletions groqeval/metrics/bias.py
@@ -1,6 +1,7 @@
# groqeval/metrics/bias.py
import json
from groq import Groq
from cachetools import cached, TTLCache
from groqeval.models.output import Output, ScoredOutput
from groqeval.metrics.base_metric import BaseMetric

@@ -16,6 +17,8 @@ def __init__(self, groq_client: Groq, output: str, prompt: str, **kwargs):
super().__init__(groq_client, kwargs.get('verbose'))
self.output = output
self.prompt = prompt
self.aggregation = max

self.check_data_types(prompt=prompt, output=output)

@property
@@ -79,6 +82,7 @@ def output_decomposition(self):
self.logger.info("Decomposition of the Output into Opinions: %s", response.choices[0].message.content)
return Output.model_validate_json(response.choices[0].message.content)

@cached(cache=TTLCache(maxsize=100, ttl=300))
def score_bias(self):
"""
Each opinion in the output is scored on a scale from 1 (completely unbiased)
@@ -99,17 +103,7 @@ def score_bias(self):
)
self.logger.info("Breakdown of the Bias Score: %s", response.choices[0].message.content)
return ScoredOutput.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)

def score(self):
scored_output, output_dictionary = self.score_bias()
if scored_output.scores:
average_score = max([output.score for output in scored_output.scores])
return {
'score': average_score,
'score_breakdown': output_dictionary
}
else:
return {
'score': 0, # Default to 0 if there are no sentences to score
'score_breakdown': output_dictionary
}

@property
def scoring_function(self):
return self.score_bias
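
Bias (and Toxicity further down) keeps its max-based behaviour by setting self.aggregation = max in __init__, so the headline score still reflects the single most biased opinion rather than the average, while score() now accepts a callable to override that per call. A quick illustration with hypothetical per-opinion scores:

    import statistics

    opinion_scores = [2, 9, 3]                        # hypothetical bias scores
    print(max(opinion_scores))                        # 9    -- default for Bias/Toxicity
    print(round(statistics.mean(opinion_scores), 2))  # 4.67 -- via bias.score(statistics.mean)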
18 changes: 5 additions & 13 deletions groqeval/metrics/context_relevance.py
@@ -2,6 +2,7 @@
import json
from typing import List
from groq import Groq
from cachetools import cached, TTLCache
from groqeval.models.context import Context, ScoredContext
from groqeval.metrics.base_metric import BaseMetric

@@ -88,6 +89,7 @@ def context_decomposition(self):
self.logger.info("Decomposition of the Context into Statements: %s", response.choices[0].message.content)
return Context.model_validate_json(response.choices[0].message.content)

@cached(cache=TTLCache(maxsize=100, ttl=300))
def score_relevance(self):
"""
Each statement of context is evaluated to determine if it can be
@@ -113,16 +115,6 @@ def score_relevance(self):
self.logger.info("Breakdown of the Context Relevance Score: %s", response.choices[0].message.content)
return ScoredContext.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)

def score(self):
scored_context, output_dictionary = self.score_relevance()
if scored_context.scores:
average_score = sum([context.score for context in scored_context.scores]) / len(scored_context.scores)
return {
'score': average_score,
'score_breakdown': output_dictionary
}
else:
return {
'score': 0, # Default to 0 if there are no sentences to score
'score_breakdown': output_dictionary
}
@property
def scoring_function(self):
return self.score_relevance
22 changes: 7 additions & 15 deletions groqeval/metrics/faithfulness.py
@@ -2,6 +2,7 @@
import json
from typing import List
from groq import Groq
from cachetools import cached, TTLCache
from groqeval.models.output import Output, ScoredOutput
from groqeval.metrics.base_metric import BaseMetric

@@ -15,7 +16,7 @@ class Faithfulness(BaseMetric):
def __init__(self, groq_client: Groq, context: List[str], output: str, **kwargs):
super().__init__(groq_client, kwargs.get('verbose'))
self.context = context
self.output = output
self.output = output
self.check_data_types(context=context, output=output)

@property
@@ -89,6 +90,7 @@ def output_decomposition(self):
self.logger.info("Decomposition of the Output into Claims: %s", response.choices[0].message.content)
return Output.model_validate_json(response.choices[0].message.content)

@cached(cache=TTLCache(maxsize=100, ttl=300))
def score_faithfulness(self):
"""
Claims are then scored on a scale from 1 to 10.
@@ -114,17 +116,7 @@ def score_faithfulness(self):
)
self.logger.info("Breakdown of the Faithfulness Score: %s", response.choices[0].message.content)
return ScoredOutput.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)

def score(self):
scored_output, output_dictionary = self.score_faithfulness()
if scored_output.scores:
average_score = sum([output.score for output in scored_output.scores]) / len(scored_output.scores)
return {
'score': average_score,
'score_breakdown': output_dictionary
}
else:
return {
'score': 0, # Default to 0 if there are no sentences to score
'score_breakdown': output_dictionary
}

@property
def scoring_function(self):
return self.score_faithfulness
18 changes: 5 additions & 13 deletions groqeval/metrics/hallucination.py
@@ -2,6 +2,7 @@
import json
from typing import List
from groq import Groq
from cachetools import cached, TTLCache
from groqeval.models.context import Context, ScoredContext
from groqeval.metrics.base_metric import BaseMetric

@@ -98,6 +99,7 @@ def context_decomposition(self):
self.logger.info("Decomposition of the Context into Statements: %s", response.choices[0].message.content)
return Context.model_validate_json(response.choices[0].message.content)

@cached(cache=TTLCache(maxsize=100, ttl=300))
def score_hallucination(self):
"""
The hallucination metric evaluates the alignment between an output and its context,
@@ -119,16 +121,6 @@ def score_hallucination(self):
self.logger.info("Breakdown of the Hallucination Score: %s", response.choices[0].message.content)
return ScoredContext.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)

def score(self):
scored_context, output_dictionary = self.score_hallucination()
if scored_context.scores:
average_score = sum([context.score for context in scored_context.scores]) / len(scored_context.scores)
return {
'score': average_score,
'score_breakdown': output_dictionary
}
else:
return {
'score': 0, # Default to 0 if there are no sentences to score
'score_breakdown': output_dictionary
}
@property
def scoring_function(self):
return self.score_hallucination
21 changes: 8 additions & 13 deletions groqeval/metrics/toxicity.py
@@ -1,6 +1,7 @@
# groqeval/metrics/toxicity.py
import json
from groq import Groq
from cachetools import cached, TTLCache
from groqeval.models.output import Output, ScoredOutput
from groqeval.metrics.base_metric import BaseMetric

@@ -16,6 +17,8 @@ def __init__(self, groq_client: Groq, output: str, prompt: str, **kwargs):
super().__init__(groq_client, kwargs.get('verbose'))
self.output = output
self.prompt = prompt
self.aggregation = max

self.check_data_types(prompt=prompt, output=output)


@@ -78,6 +81,7 @@ def output_decomposition(self):
self.logger.info("Breakdown of the Toxicity Score: %s", response.choices[0].message.content)
return Output.model_validate_json(response.choices[0].message.content)

@cached(cache=TTLCache(maxsize=100, ttl=300))
def score_toxicity(self):
"""
Each phrase is examined to see if it represents an opinion
@@ -100,16 +104,7 @@ def score_toxicity(self):
)
return ScoredOutput.model_validate_json(response.choices[0].message.content), json.loads(response.choices[0].message.content)

def score(self):
scored_output, output_dictionary = self.score_toxicity()
if scored_output.scores:
average_score = max([output.score for output in scored_output.scores])
return {
'score': average_score,
'score_breakdown': output_dictionary
}
else:
return {
'score': 0, # Default to 0 if there are no sentences to score
'score_breakdown': output_dictionary
}
@property
def scoring_function(self):
return self.score_toxicity

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -22,7 +22,8 @@ requires-python = ">=3.10"

dependencies = [
"groq>=0.9.0",
"pydantic>=2.7.4"
"pydantic>=2.7.4",
"cachetools>=5.3.3"
]

[tool.twine]
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,2 +1,3 @@
groq==0.9.0
pydantic==2.7.4
pydantic==2.7.4
cachetools==5.3.3
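
The new cachetools pin backs the @cached(cache=TTLCache(maxsize=100, ttl=300)) decorators added above: each metric's scoring call is memoized (keyed on the metric instance) for up to 300 seconds, so repeated score(...) calls on the same metric, as in the tests below, reuse one model response instead of issuing a new request each time. A stubbed sketch of that behaviour:

    from cachetools import cached, TTLCache

    class CacheSketch:
        calls = 0

        @cached(cache=TTLCache(maxsize=100, ttl=300))
        def score_stub(self):
            CacheSketch.calls += 1                # stands in for the chat-completion call
            return [5, 7], {"scores": [5, 7]}

    sketch = CacheSketch()
    sketch.score_stub()
    sketch.score_stub()
    print(CacheSketch.calls)                      # 1 -- second call served from the cache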
8 changes: 8 additions & 0 deletions tests/metrics/test_answer_relevance.py
@@ -1,5 +1,6 @@
import pytest
import math
import statistics

@pytest.mark.parametrize("prompt, output, expected_score_range", [
("Discuss the implications of quantum computing on data security.",
@@ -23,6 +24,13 @@ def test_answer_relevance_scoring(evaluator, prompt, output, expected_score_range):
result = answer_relevance.score()
assert math.ceil(result['score']) >= expected_score_range[0] and math.floor(result['score']) <= expected_score_range[1], f"Score {result['score']} not in range {expected_score_range}"

max_score = answer_relevance.score(max)['score']
mean_score = answer_relevance.score(statistics.mean)['score']
min_score = answer_relevance.score(min)['score']

assert max_score >= mean_score, f"Max score {max_score} is not greater than mean score {mean_score}"
assert min_score <= mean_score, f"Min score {min_score} is not less than mean score {mean_score}"

def test_answer_relevance_empty_output(evaluator):
prompt = "What are the benefits of meditation?"
output = ""
8 changes: 8 additions & 0 deletions tests/metrics/test_bias.py
@@ -1,5 +1,6 @@
import pytest
import math
import statistics

@pytest.mark.parametrize("prompt, output, expected_score_range", [
@@ -24,6 +25,13 @@ def test_bias_scoring(evaluator, prompt, output, expected_score_range):
result = bias.score()
assert math.ceil(result['score']) >= expected_score_range[0] and math.floor(result['score']) <= expected_score_range[1], f"Score {result['score']} not in range {expected_score_range}"

max_score = bias.score(max)['score']
mean_score = bias.score(statistics.mean)['score']
min_score = bias.score(min)['score']

assert max_score >= mean_score, f"Max score {max_score} is not greater than mean score {mean_score}"
assert min_score <= mean_score, f"Min score {min_score} is not less than mean score {mean_score}"

def test_bias_empty_output(evaluator):
prompt = "Evaluate the current role of renewable energy in economic development."
output = ""
9 changes: 9 additions & 0 deletions tests/metrics/test_context_relevance.py
@@ -1,5 +1,6 @@
import pytest
import math
import statistics

@pytest.mark.parametrize("prompt, context, expected_score_range", [
("Describe the impact of climate change on polar bears.",
@@ -23,6 +24,14 @@ def test_context_relevance_scoring(evaluator, prompt, context, expected_score_range):
result = context_relevance.score()
assert math.ceil(result['score']) >= expected_score_range[0] and math.floor(result['score']) <= expected_score_range[1], f"Score {result['score']} not in range {expected_score_range}"

max_score = context_relevance.score(max)['score']
mean_score = context_relevance.score(statistics.mean)['score']
min_score = context_relevance.score(min)['score']

assert max_score >= mean_score, f"Max score {max_score} is not greater than mean score {mean_score}"
assert min_score <= mean_score, f"Min score {min_score} is not less than mean score {mean_score}"


def test_context_relevance_empty_context(evaluator):
prompt = "What are the benefits of meditation?"
context = []