Skip to content

Commit

Permalink
✨ Add ChrF/++ metric
Browse files Browse the repository at this point in the history
  • Loading branch information
klh5 committed Jan 20, 2025
1 parent c8d60ef commit d0b92ca
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/m4st/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,32 @@
from sonar.models.blaser.loader import load_blaser_model


class ChrFScore:
"""Applies ChrF/++ from the evaluate library.
When word_order=0 (default) computes original ChrF metric without including word
n-grams. When word_order=2, computes ChrF++. The DEMETR paper refers to ChrF++
as ChrF2.For more details see https://huggingface.co/spaces/evaluate-metric/chrf"""

def __init__(self, word_order: int = 0) -> None:
self.chrf = evaluate.load("chrf")
self.word_order = word_order

def get_scores(self, references: Series, predictions: Series) -> list:
results = []

for index, ref_txt in references.items():
mt_txt = predictions[index]
score = self.chrf.compute(
predictions=[mt_txt],
references=[[ref_txt]],
word_order=self.word_order,
eps_smoothing=True,
)
results.append(score["score"])

return results


class SacreBLEUScore:
"""Applies SacreBLEU from the evaluate library."""

Expand Down
11 changes: 11 additions & 0 deletions src/m4st/process_demetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from m4st.metrics import (
BLASERQEScore,
BLASERRefScore,
ChrFScore,
COMETQEScore,
COMETRefScore,
SacreBLEUScore,
Expand Down Expand Up @@ -57,6 +58,10 @@ def __init__(
self.comet_ref = COMETRefScore()
if "COMET_qe" in self.metrics_to_use:
self.comet_qe = COMETQEScore()
if "ChrF" in self.metrics_to_use:
self.chrf = ChrFScore(word_order=1)
if "ChrF2" in self.metrics_to_use:
self.chrf2 = ChrFScore(word_order=2)

print(f"Using metrics {self.metrics_to_use}")

Expand Down Expand Up @@ -114,6 +119,12 @@ def process_demetr_category(
elif metric == "SacreBLEU":
mt_results[:, j] = self.sacre_bleu.get_scores(ref_txts, mt_txts)
dis_results[:, j] = self.sacre_bleu.get_scores(ref_txts, dfluent_txts)
elif metric == "ChrF":
mt_results[:, j] = self.chrf.get_scores(ref_txts, mt_txts)
dis_results[:, j] = self.chrf.get_scores(ref_txts, dfluent_txts)
elif metric == "ChrF2":
mt_results[:, j] = self.chrf2.get_scores(ref_txts, mt_txts)
dis_results[:, j] = self.chrf2.get_scores(ref_txts, dfluent_txts)
else:
print(f"Unknown metric {metric}")

Expand Down

0 comments on commit d0b92ca

Please sign in to comment.