Skip to content

Commit

Permalink
Merge pull request #56 from alexandrainst/feat/add-cer-metric
Browse files Browse the repository at this point in the history
  • Loading branch information
saattrupdan authored Dec 14, 2023
2 parents d9d09de + 3f0dbf6 commit 1eebe7a
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions src/coral_models/compute_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str,
dictionary with 'wer' as the key and the word error rate as the value.
"""
wer_metric = load_metric("wer")
cer_metric = load_metric("cer")
tokenizer: PreTrainedTokenizerBase = getattr(processor, "tokenizer")
pad_token = tokenizer.pad_token_id

Expand Down Expand Up @@ -84,13 +85,26 @@ def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str,
logger.info(f"Sample document: {labels_str[random_idx]}")
logger.info(f"Predicted: {predictions_str[random_idx]}")

# Compute the word error rate
computed = wer_metric.compute(predictions=predictions_str, references=labels_str)
assert computed is not None
metrics: dict[str, float] = dict()

# Ensure that `wer` is a dict, as metrics in the `evaluate` library can either be
# dicts or floats
if not isinstance(computed, dict):
return dict(wer=computed)
# Compute the word error rate
wer_computed = wer_metric.compute(
predictions=predictions_str, references=labels_str
)
assert wer_computed is not None
if not isinstance(wer_computed, dict):
metrics = metrics | dict(wer=wer_computed)
else:
return computed
metrics = metrics | wer_computed

# Compute the character error rate
cer_computed = cer_metric.compute(
predictions=predictions_str, references=labels_str
)
assert cer_computed is not None
if not isinstance(cer_computed, dict):
metrics = metrics | dict(cer=cer_computed)
else:
metrics = metrics | cer_computed

return metrics

0 comments on commit 1eebe7a

Please sign in to comment.