diff --git a/src/coral_models/compute_metrics.py b/src/coral_models/compute_metrics.py
index 9e252ecf..79e0a7e5 100644
--- a/src/coral_models/compute_metrics.py
+++ b/src/coral_models/compute_metrics.py
@@ -26,6 +26,7 @@ def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str,
         dictionary with 'wer' as the key and the word error rate as the value.
     """
     wer_metric = load_metric("wer")
+    cer_metric = load_metric("cer")
     tokenizer: PreTrainedTokenizerBase = getattr(processor, "tokenizer")
     pad_token = tokenizer.pad_token_id
@@ -84,13 +85,26 @@ def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str,
         logger.info(f"Sample document: {labels_str[random_idx]}")
         logger.info(f"Predicted: {predictions_str[random_idx]}")
 
-    # Compute the word error rate
-    computed = wer_metric.compute(predictions=predictions_str, references=labels_str)
-    assert computed is not None
+    metrics: dict[str, float] = dict()
 
-    # Ensure that `wer` is a dict, as metrics in the `evaluate` library can either be
-    # dicts or floats
-    if not isinstance(computed, dict):
-        return dict(wer=computed)
+    # Compute the word error rate
+    wer_computed = wer_metric.compute(
+        predictions=predictions_str, references=labels_str
+    )
+    assert wer_computed is not None
+    if not isinstance(wer_computed, dict):
+        metrics = metrics | dict(wer=wer_computed)
     else:
-        return computed
+        metrics = metrics | wer_computed
+
+    # Compute the character error rate
+    cer_computed = cer_metric.compute(
+        predictions=predictions_str, references=labels_str
+    )
+    assert cer_computed is not None
+    if not isinstance(cer_computed, dict):
+        metrics = metrics | dict(cer=cer_computed)
+    else:
+        metrics = metrics | cer_computed
+
+    return metrics
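
For context, a minimal standalone sketch (not part of the patch) of what the new metric merging does, assuming `load_metric` in this repo wraps Hugging Face `evaluate.load` (an assumption; the wrapper itself is not shown in this diff). Both the `wer` and `cer` metrics return plain floats from `compute`, which is why the patch wraps them in dicts before merging:

    # Illustrative sketch, assuming `load_metric` wraps `evaluate.load`
    # from the Hugging Face `evaluate` library (requires `jiwer` installed).
    import evaluate

    wer_metric = evaluate.load("wer")
    cer_metric = evaluate.load("cer")

    # Hypothetical decoded predictions and references
    predictions_str = ["det er en test", "hej verden"]
    labels_str = ["det var en test", "hej verden"]

    metrics: dict[str, float] = dict()

    # `compute` returns a plain float for these metrics, so wrap it in a dict
    # before merging, mirroring the isinstance checks in the patch
    wer_computed = wer_metric.compute(predictions=predictions_str, references=labels_str)
    metrics |= wer_computed if isinstance(wer_computed, dict) else dict(wer=wer_computed)

    cer_computed = cer_metric.compute(predictions=predictions_str, references=labels_str)
    metrics |= cer_computed if isinstance(cer_computed, dict) else dict(cer=cer_computed)

    print(metrics)  # {'wer': ..., 'cer': ...}

The dict-union operators (`|`, `|=`) used here and in the patch require Python 3.9 or newer.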