feat: add content message to model quality text generation #220

Merged
1 change: 1 addition & 0 deletions api/app/models/metrics/model_quality_dto.py
@@ -199,6 +199,7 @@ class TokenProb(BaseModel):

 class TokenData(BaseModel):
     id: str
+    message_content: str
     probs: List[TokenProb]
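For context on the API-side change, here is a minimal sketch of the extended DTO in use, assuming pydantic v2; `TokenProb`'s `prob`/`token` fields are inferred from the test fixtures later in this PR, not shown in this hunk:

from typing import List

from pydantic import BaseModel


class TokenProb(BaseModel):
    # Field names inferred from this PR's test fixtures.
    prob: float
    token: str


class TokenData(BaseModel):
    id: str
    message_content: str  # new required field added by this PR
    probs: List[TokenProb]


# Since message_content has no default, payloads that omit it
# will now fail validation.
data = TokenData.model_validate(
    {
        "id": "chatcmpl",
        "message_content": "Sky is blue.",
        "probs": [{"prob": 0.27, "token": "Sky"}],
    }
)
print(data.message_content)  # Sky is blue.

Because the field is required, metrics payloads produced before this change will no longer validate against the updated DTO.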
1 change: 1 addition & 0 deletions api/tests/commons/db_mock.py
@@ -550,6 +550,7 @@ def get_sample_completion_dataset(
     'tokens': [
         {
             'id': 'chatcmpl',
+            'message_content': 'Sky is blue.',
             'probs': [
                 {'prob': 0.27718424797058105, 'token': 'Sky'},
                 {'prob': 0.8951022028923035, 'token': ' is'},
@@ -198,6 +198,7 @@ class TokenProb(BaseModel):

 class TokenData(BaseModel):
     id: str
+    message_content: str
     probs: List[TokenProb]


4 changes: 3 additions & 1 deletion sdk/tests/apis/model_completion_dataset_test.py
@@ -43,7 +43,8 @@ def test_text_generation_model_quality_ok(self):
             {
                 "tokens": [
                     {
-                        "id":"chatcmpl",
+                        "id": "chatcmpl",
+                        "message_content": "Sky is blue.",
                         "probs":[
                             {
                                 "prob":0.27,
@@ -84,6 +85,7 @@ def test_text_generation_model_quality_ok(self):
         metrics = model_completion_dataset.model_quality()

         assert isinstance(metrics, CompletionTextGenerationModelQuality)
+        assert metrics.tokens[0].message_content == 'Sky is blue.'
         assert metrics.tokens[0].probs[0].prob == 0.27
         assert metrics.tokens[0].probs[0].token == 'Sky'
         assert metrics.mean_per_file[0].prob_tot_mean == 0.71
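Judging from these assertions, SDK consumers can now read the generated message next to its per-token probabilities. A hedged usage sketch follows; `model_completion_dataset` stands in for however the dataset handle is actually obtained from the SDK client:

# Hypothetical handle obtained from the SDK client.
metrics = model_completion_dataset.model_quality()

for token_data in metrics.tokens:
    print(token_data.message_content)  # full generated message (new in this PR)
    for token_prob in token_data.probs:
        print(f"  {token_prob.token!r}: {token_prob.prob:.2f}")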
17 changes: 12 additions & 5 deletions spark/jobs/metrics/completion_metrics.py
@@ -38,20 +38,25 @@ def remove_columns(df: DataFrame) -> DataFrame:
         return df

     def compute_prob(self, df: DataFrame):
-        df = df.select(F.explode("choices").alias("element"), F.col("id"))
         df = df.select(
-            F.col("id"), F.explode("element.logprobs.content").alias("content")
+            F.explode("choices").alias("element"),
+            F.col("id"),
         )
-        df = df.select("id", "content.logprob", "content.token").withColumn(
-            "prob", self.compute_probability_udf("logprob")
+        df = df.select(
+            F.col("id"),
+            F.col("element.message.content").alias("message_content"),
+            F.explode("element.logprobs.content").alias("content"),
         )
+        df = df.select(
+            "id", "message_content", "content.logprob", "content.token"
+        ).withColumn("prob", self.compute_probability_udf("logprob"))
         return df

     def extract_metrics(self, df: DataFrame) -> CompletionMetricsModel:
         df = self.remove_columns(df)
         df = self.compute_prob(df)
         df_prob = df.drop("logprob")
-        df_prob = df_prob.groupBy("id").agg(
+        df_prob = df_prob.groupBy("id", "message_content").agg(
             F.collect_list(F.struct("token", "prob")).alias("probs")
         )
         df_mean_values = df.groupBy("id").agg(
@@ -66,9 +71,11 @@ def extract_metrics(self, df: DataFrame) -> CompletionMetricsModel:
             F.mean("prob_per_phrase").alias("prob_tot_mean"),
             F.mean("perplex_per_phrase").alias("perplex_tot_mean"),
         )
+        df_prob = df_prob.orderBy("id")
         tokens = [
             {
                 "id": row["id"],
+                "message_content": row["message_content"],
                 "probs": [
                     {"token": prob["token"], "prob": prob["prob"]}
                     for prob in row["probs"]
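To make the reworked `compute_prob` concrete, here is a self-contained sketch of the same transformation run on a hand-built dataframe shaped like an OpenAI-style chat-completion record. The schema is an assumption based on the columns this diff reads, and `F.exp("logprob")` stands in for the job's `compute_probability_udf`:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()

# A single completion record: each choice carries the generated
# message and the per-token logprobs, mirroring the columns the
# job reads after remove_columns().
df = spark.createDataFrame(
    [
        {
            "id": "chatcmpl",
            "choices": [
                {
                    "message": {"content": "Sky is blue."},
                    "logprobs": {
                        "content": [
                            {"token": "Sky", "logprob": -1.283},
                            {"token": " is", "logprob": -0.111},
                        ]
                    },
                }
            ],
        }
    ],
    schema=(
        "id string, "
        "choices array<struct<"
        "message:struct<content:string>, "
        "logprobs:struct<content:array<struct<token:string, logprob:double>>>"
        ">>"
    ),
)

# Mirrors the new compute_prob: one row per choice, then one row per
# token, with the message text carried along on every row.
df = df.select(F.explode("choices").alias("element"), F.col("id"))
df = df.select(
    F.col("id"),
    F.col("element.message.content").alias("message_content"),
    F.explode("element.logprobs.content").alias("content"),
)
df = df.select(
    "id", "message_content", "content.logprob", "content.token"
).withColumn("prob", F.exp("logprob"))  # stand-in for compute_probability_udf

df.show(truncate=False)
# Columns: id, message_content, logprob, token, prob (one row per token).

Carrying `message_content` on every token row is what lets `extract_metrics` group by `("id", "message_content")` without losing the message text, and the added `orderBy("id")` makes the emitted token list deterministic across runs.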
1 change: 1 addition & 0 deletions spark/jobs/models/completion_dataset.py
@@ -9,6 +9,7 @@ class Prob(BaseModel):

 class Probs(BaseModel):
     id: str
+    message_content: str
     probs: List[Prob]

     model_config = ConfigDict(ser_json_inf_nan="null")
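A side note on the `model_config` kept as context above: in pydantic v2, `ser_json_inf_nan="null"` serializes infinite or NaN floats as JSON null rather than emitting invalid JSON. A small sketch; `Prob`'s fields are inferred from this PR's fixtures and the NaN value is invented for illustration:

import math
from typing import List

from pydantic import BaseModel, ConfigDict


class Prob(BaseModel):
    # Field names inferred from this PR's fixtures.
    token: str
    prob: float


class Probs(BaseModel):
    id: str
    message_content: str
    probs: List[Prob]

    model_config = ConfigDict(ser_json_inf_nan="null")


p = Probs(
    id="chatcmpl",
    message_content="Sky is blue.",
    probs=[Prob(token="Sky", prob=math.nan)],
)
# NaN is not valid JSON; with ser_json_inf_nan="null" it is emitted as null.
print(p.model_dump_json())
# {"id":"chatcmpl","message_content":"Sky is blue.","probs":[{"token":"Sky","prob":null}]}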
2 changes: 1 addition & 1 deletion spark/tests/completion_metrics_test.py
@@ -25,7 +25,7 @@ def test_compute_prob(spark_fixture, input_file):
     completion_metrics_service = CompletionMetrics()
     df = completion_metrics_service.remove_columns(input_file)
     df = completion_metrics_service.compute_prob(df)
-    assert {"id", "logprob", "token", "prob"} == set(df.columns)
+    assert {"id", "logprob", "message_content", "token", "prob"} == set(df.columns)
     assert not df.rdd.isEmpty()


2 changes: 2 additions & 0 deletions spark/tests/results/completion_metrics_results.py
@@ -2,6 +2,7 @@
"tokens": [
{
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
"message_content": "Sure, go ahead. What's up?",
"probs": [
{"token": "Sure", "prob": 0.541987419128418},
{"token": ",", "prob": 0.9025230407714844},
@@ -15,6 +16,7 @@
         },
         {
             "id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
+            "message_content": "Certainly! Just let me know how.",
             "probs": [
                 {"token": "Certainly", "prob": 0.022015240043401718},
                 {"token": "!", "prob": 0.8896080851554871},