From d0b92ca819b1384cf886cf8b13ff22f669e0a1c4 Mon Sep 17 00:00:00 2001 From: klh5 Date: Mon, 20 Jan 2025 10:42:18 +0000 Subject: [PATCH] :sparkles: Add ChrF/++ metric --- src/m4st/metrics.py | 26 ++++++++++++++++++++++++++ src/m4st/process_demetr.py | 11 +++++++++++ 2 files changed, 37 insertions(+) diff --git a/src/m4st/metrics.py b/src/m4st/metrics.py index 5d3f455..6d04161 100644 --- a/src/m4st/metrics.py +++ b/src/m4st/metrics.py @@ -5,6 +5,32 @@ from sonar.models.blaser.loader import load_blaser_model +class ChrFScore: + """Applies ChrF/++ from the evaluate library. + When word_order=0 (default) computes original ChrF metric without including word + n-grams. When word_order=2, computes ChrF++. The DEMETR paper refers to ChrF++ + as ChrF2.For more details see https://huggingface.co/spaces/evaluate-metric/chrf""" + + def __init__(self, word_order: int = 0) -> None: + self.chrf = evaluate.load("chrf") + self.word_order = word_order + + def get_scores(self, references: Series, predictions: Series) -> list: + results = [] + + for index, ref_txt in references.items(): + mt_txt = predictions[index] + score = self.chrf.compute( + predictions=[mt_txt], + references=[[ref_txt]], + word_order=self.word_order, + eps_smoothing=True, + ) + results.append(score["score"]) + + return results + + class SacreBLEUScore: """Applies SacreBLEU from the evaluate library.""" diff --git a/src/m4st/process_demetr.py b/src/m4st/process_demetr.py index de8779d..69eb3ee 100644 --- a/src/m4st/process_demetr.py +++ b/src/m4st/process_demetr.py @@ -11,6 +11,7 @@ from m4st.metrics import ( BLASERQEScore, BLASERRefScore, + ChrFScore, COMETQEScore, COMETRefScore, SacreBLEUScore, @@ -57,6 +58,10 @@ def __init__( self.comet_ref = COMETRefScore() if "COMET_qe" in self.metrics_to_use: self.comet_qe = COMETQEScore() + if "ChrF" in self.metrics_to_use: + self.chrf = ChrFScore(word_order=1) + if "ChrF2" in self.metrics_to_use: + self.chrf2 = ChrFScore(word_order=2) print(f"Using metrics {self.metrics_to_use}") @@ -114,6 +119,12 @@ def process_demetr_category( elif metric == "SacreBLEU": mt_results[:, j] = self.sacre_bleu.get_scores(ref_txts, mt_txts) dis_results[:, j] = self.sacre_bleu.get_scores(ref_txts, dfluent_txts) + elif metric == "ChrF": + mt_results[:, j] = self.chrf.get_scores(ref_txts, mt_txts) + dis_results[:, j] = self.chrf.get_scores(ref_txts, dfluent_txts) + elif metric == "ChrF2": + mt_results[:, j] = self.chrf2.get_scores(ref_txts, mt_txts) + dis_results[:, j] = self.chrf2.get_scores(ref_txts, dfluent_txts) else: print(f"Unknown metric {metric}")