Npfl 101 comet #75

Open · wants to merge 4 commits into master
14 changes: 14 additions & 0 deletions README.md
@@ -269,6 +269,16 @@ tot sacreBLEU docAsAWhole 32.786
avg sacreBLEU mwerSegmenter 25.850
```

If you also want to calculate the COMET score, you need to include the OSt file in the source language as ``src``, as shown below:
```
MTeval -i sample-data/sample.en.cs.mt sample-data/sample.en.OSt sample-data/sample.cs.OSt -f mt src ref
```
This adds one additional line to the output, reporting the COMET score:
```
tot COMET docAsWhole 0.770
```
Providing the source file (and thus the COMET score) is optional.

#### Evaluating SLT <a name="Evaluating-SLT"></a>

Spoken language translation evaluates "machine translation in time". So a time-stamped MT output (``slt``) is compared with the reference translation (non-timed, ``ref``) and the timing of the golden transcript (``ostt``).
@@ -293,6 +303,7 @@ tot Flicker count_changed_content 23
tot sacreBLEU docAsAWhole 32.786
...
```
As with MTeval, to calculate the COMET score you need to include the OSt file in the source language as ``src``; see the example below.
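A sketch of such a call, reusing the ``sample-data`` naming from the MTeval example above (the ``.slt`` and ``.OStt`` file names here are illustrative):
```
SLTeval -i sample-data/sample.en.cs.slt sample-data/sample.en.OSt sample-data/sample.cs.OSt sample-data/sample.en.OStt -f slt src ref ostt
```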


#### Evaluating ASR <a name="Evaluating-ASR"></a>
@@ -392,6 +403,9 @@ Usage:
SLTIndexParser path_to_index_file path_to_dataset
```

5. Note that a stable internet connection is required so that the COMET model can be downloaded to the local system when the COMET score is calculated; see the sketch below for pre-fetching the model once.
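Since ``download_model`` stores the checkpoint locally (and should reuse it on later calls), one way to avoid needing connectivity at evaluation time is to pre-fetch the model once while online. A minimal sketch using the same helper imported in ``quality_modules.py`` (where the copy ends up depends on your comet installation):
```python
# One-off pre-fetch of the COMET checkpoint so that later runs can reuse the local copy.
from comet import download_model

checkpoint_path = download_model("Unbabel/wmt22-comet-da")
print("COMET checkpoint stored at:", checkpoint_path)
```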

## Terminology and Abbreviations <a name="Terminology-and-Abbreviations"></a>

* OSt ... original speech manually transcribed (i.e. golden transcript)
3 changes: 2 additions & 1 deletion SLTev/ASReval.py
@@ -168,7 +168,8 @@ def main(input_files=[], file_formats=[], arguments={}):
'ostt': read_ostt_file(gold_files["ostt"][0]),
'references': read_references(gold_files["ost"]),
'SLTev_home': SLTev_home,
'candidate': read_candidate_file(candidate_file[0])
'candidate': read_candidate_file(candidate_file[0]),
'src': read_references(gold_files["ost"])
}

_ = check_time_stamp_candiates_format(candidate_file[0], split_token) # submission checking
1 change: 1 addition & 0 deletions SLTev/MTeval.py
@@ -82,6 +82,7 @@ def main(inputs=[], file_formats=[], arguments={}):
evaluation_object = {
'references': read_references(gold_files["ref"]),
'mt': read_candidate_file(candidate_file[0]),
'src': read_references(gold_files["src"]),
'SLTev_home': sltev_home,
}

5 changes: 4 additions & 1 deletion SLTev/SLTev.py
@@ -361,7 +361,7 @@ def slt_submission_evaluation(args, inputs_object):


def build_input_fils_and_file_formats(submission_file, gold_input_files):
status, references, ostt, aligns = split_gold_inputs_submission_in_working_directory(submission_file, gold_input_files)
status, src, references, ostt, aligns = split_gold_inputs_submission_in_working_directory(submission_file, gold_input_files)

input_files = [submission_file]
file_formats = [remove_digits(status)]
@@ -375,6 +375,9 @@ def build_input_fils_and_file_formats(submission_file, gold_input_files):
if ostt != "":
input_files.append(ostt)
file_formats.append("ostt")
if src != "":
input_files.append(src)
file_formats.append("src")

for align_file in aligns:
input_files.append(align_file)
1 change: 1 addition & 0 deletions SLTev/SLTeval.py
@@ -97,6 +97,7 @@ def main(input_files=[], file_formats=[], arguments={}):
'references': read_references(gold_files["ref"]),
'candidate': read_candidate_file(candidate_file[0]),
'align': gold_files["align"],
'src': read_references(gold_files["src"]),
'SLTev_home': sltev_home,
}

19 changes: 18 additions & 1 deletion SLTev/evaluator.py
@@ -9,7 +9,7 @@
from flicker_modules import calc_revise_count, calc_flicker_score
from flicker_modules import calc_average_flickers_per_sentence, calc_average_flickers_per_tokens
from quality_modules import calc_bleu_score_documentlevel, calc_bleu_score_segmenterlevel
from quality_modules import calc_bleu_score_timespanlevel
from quality_modules import calc_bleu_score_timespanlevel, calculate_comet_score
from utilities import mwerSegmenter_error_message, eprint
from files_modules import read_alignment_file

@@ -268,6 +268,8 @@ def normal_evaluation_without_parity(inputs_object):
references_statistical_info(references) # print statistical info
average_refernces_token_count = get_average_references_token_count(references)
candidate_sentences = inputs_object.get('candidate')
mt_sentences = inputs_object.get('mt')
src_file = inputs_object.get('src')

evaluation_object = {
'candidate_sentences': candidate_sentences,
@@ -283,6 +285,8 @@
# bleu score evaluation
documantlevel_bleu_score_evaluation(references, candidate_sentences)
wordbased_segmenter_bleu_score_evaluation(evaluation_object)
if src_file != '' and src_file != []:
comet_score_evaluation(src_file, mt_sentences, references)

#flicker evaluation
print("tot Flicker count_changed_Tokens ", int(calc_revise_count(candidate_sentences)))
@@ -338,6 +342,7 @@ def normal_timestamp_evaluation(inputs_object):
average_refernces_token_count = get_average_references_token_count(references)
candidate_sentences = inputs_object.get('candidate')
OStt_sentences = inputs_object.get('ostt')
src_file = inputs_object.get('src')
print_ostt_duration(OStt_sentences)
Ts = []
for reference in references:
@@ -373,6 +378,8 @@
documantlevel_bleu_score_evaluation(references, candidate_sentences)
wordbased_segmenter_bleu_score_evaluation(evaluation_object)
time_span_bleu_score_evaluation(evaluation_object)
if src_file != '' and src_file != []:
comet_score_evaluation(src_file, candidate_sentences, references)
#flicker evaluation
print("tot Flicker count_changed_Tokens ", int(calc_revise_count(candidate_sentences)))
print("tot Flicker count_changed_content ", int(calc_flicker_score(candidate_sentences)))
@@ -385,6 +392,13 @@
str("{0:.3f}".format(round(calc_average_flickers_per_tokens(candidate_sentences), 3))),
)

def comet_score_evaluation(src_file, mt_sentences, references):
comet_score, success = calculate_comet_score(src_file, mt_sentences, references)
if success:
print(
"tot COMET docAsWhole ",
str("{0:.3f}".format(round(comet_score, 3))),
)

def simple_mt_evaluation(inputs_object):
current_path = os.getcwd()
@@ -413,6 +427,7 @@ def normal_mt_evaluation(inputs_object):
references_statistical_info(references) # print statistical info
average_refernces_token_count = get_average_references_token_count(references)
mt_sentences = inputs_object.get('mt')
src_file = inputs_object.get('src')

evaluation_object = {
'candidate_sentences': mt_sentences,
@@ -427,6 +442,8 @@
# bleu score evaluation
documantlevel_bleu_score_evaluation(references, mt_sentences)
wordbased_segmenter_bleu_score_evaluation(evaluation_object)
if src_file != '' and src_file != []:
comet_score_evaluation(src_file, mt_sentences, references)



36 changes: 36 additions & 0 deletions SLTev/quality_modules.py
@@ -1,7 +1,9 @@
#!/usr/bin/env python

import sacrebleu
from comet import download_model, load_from_checkpoint
from files_modules import quality_segmenter
from utilities import eprint


def calc_bleu_score_documentlevel(references, candiate_sentences):
@@ -162,3 +164,37 @@ def calc_bleu_score_timespanlevel(evaluation_object):
)
return bleu_scores, avg_SacreBleu

def calculate_comet_score(sources, candidates, references=None):
    try:
        model_name = "Unbabel/wmt22-comet-da"

        # Keep only the last (complete) segment of each candidate sentence,
        # dropping its first three fields and the trailing token.
        merge_mt_sentences = []
        for candidate in candidates:
            merge_mt_sentences += candidate[-1][3:-1]

        # Join the tokens of every reference/source sentence back into plain text.
        merge_references_sentences = []
        for ref in references:
            merge_references_sentences.append([" ".join(sentence[:-1]) for sentence in ref])

        merge_src_sentences = []
        for src in sources:
            merge_src_sentences.append([" ".join(sentence[:-1]) for sentence in src])

        # Score the document as a whole: one src/mt/ref triple per document.
        ref = [" ".join(i) for i in merge_references_sentences]
        mt = [" ".join(merge_mt_sentences)]
        src = [" ".join(i) for i in merge_src_sentences]
        data = [{'src': x[0], 'mt': x[1], 'ref': x[2]} for x in zip(src, mt, ref)]

        model_path = download_model(model_name)  # download (or reuse a previously downloaded) COMET checkpoint
        model = load_from_checkpoint(model_path)
        model_output = model.predict(data, batch_size=8, gpus=0)
        return model_output['system_score'] * 100, True
    except Exception:
        eprint("Unable to calculate the COMET score; the COMET model could not be downloaded (check your internet connection).")
        return 0, False
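For reference, the COMET calls used here (``download_model``, ``load_from_checkpoint``, ``predict``) can be exercised on their own to check the expected ``src``/``mt``/``ref`` input format. A minimal sketch with made-up sentences (not SLTev data):
```python
from comet import download_model, load_from_checkpoint

# One entry per scored segment; calculate_comet_score() builds a single entry
# covering the whole document.
data = [
    {
        "src": "Good morning. How are you?",   # source-language text (OSt)
        "mt": "Dobré ráno. Jak se máš?",       # system output being scored
        "ref": "Dobré ráno. Jak se máte?",     # reference translation (OSt)
    }
]

model_path = download_model("Unbabel/wmt22-comet-da")  # needs internet on first use
model = load_from_checkpoint(model_path)
output = model.predict(data, batch_size=8, gpus=0)     # gpus=0 keeps prediction on the CPU

print(output["system_score"])  # corpus-level score; SLTev multiplies it by 100 before printing
```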
17 changes: 15 additions & 2 deletions SLTev/utilities.py
@@ -281,7 +281,7 @@ def split_gold_inputs_submission_in_working_directory(submission_file, gold_inpu
:return tt, ostt, align: OSt, OStt, align files according to the submission file
"""

status, ostt = "", ""
status, ostt, src = "", "", ""
references, aligns = list(), list()

submission_file_name = os.path.split(submission_file)[1]
@@ -300,6 +300,11 @@
== submission_file_name_without_prefix + "." + target_lang + ".OSt"
):
references.append(file)
elif (
".".join(input_name[:-1]) + "." + remove_digits(input_name[-1])
== submission_file_name_without_prefix + "." + source_lang + ".OSt"
):
src = file
elif (
".".join(input_name[:-1]) + "." + remove_digits(input_name[-1])
== submission_file_name_without_prefix + "." + source_lang + ".OStt"
@@ -310,7 +315,7 @@
== submission_file_name_without_prefix + "." + source_lang + "." + target_lang + ".align"
):
aligns.append(file)
return status, references, ostt, aligns
return status, src, references, ostt, aligns


def mwerSegmenter_error_message():
@@ -459,6 +464,10 @@ def extract_mt_gold_files_for_candidate(candidate_file, gold_inputs):
except:
eprint( "evaluation failed, the reference file does not exist for ", candidate_file[0])
error = 1
try:
gold_files["src"] = gold_inputs["src"]
except:
gold_files["src"] = ""
return gold_files, error


@@ -479,6 +488,10 @@
gold_files["align"] = gold_inputs["align"]
except:
gold_files["align"] = []
try:
gold_files["src"] = gold_inputs["src"]
except:
gold_files["src"] = ""
return gold_files, error


1 change: 1 addition & 0 deletions requirements.txt
@@ -6,3 +6,4 @@ gitdir
jiwer
filelock
pytest
unbabel-comet==2.0.2
1 change: 1 addition & 0 deletions setup.py
@@ -32,6 +32,7 @@
"jiwer",
"filelock",
"pytest",
"unbabel-comet"
],
url="https://github.com/ELITR/SLTev.git",
classifiers=[