-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from alan-turing-institute/8-replicate-demetr-results-for-bleu
- Loading branch information
Showing
5 changed files
with
393 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import argparse | ||
import os | ||
|
||
from m4st.process_demetr import ProcessDEMETR | ||
|
||
|
||
def main(args: dict) -> None:
    """Run the DEMETR benchmark for the configured metrics.

    Args:
        args: Parsed command-line arguments as a dict with keys
            "output_dir", "output_file", "dataset_dir", "metrics" and "cats".
    """
    output_dir = args["output_dir"]
    output_file = args["output_file"]

    # Create the output directory up front so ProcessDEMETR can write its CSV.
    os.makedirs(output_dir, exist_ok=True)

    demetr = ProcessDEMETR(
        metrics_to_use=args["metrics"],
        output_filepath=os.path.join(output_dir, output_file),
        demetr_root=args["dataset_dir"],
    )

    # Removed leftover debug print of args["cats"]; the selection is passed
    # straight through to process_demetr below.
    demetr.process_demetr(cats_to_process=args["cats"])
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run DEMETR disfluency-category evaluation for one or more metrics."
    )

    parser.add_argument(
        "--dataset-dir",
        type=str,
        default="../../datasets/demetr",
        # Implicit string concatenation instead of a backslash continuation:
        # a continuation inside the literal embeds the source indentation
        # into the displayed help text.
        help="Root dataset for DEMETR containing JSON files.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="../../outputs/demetr",
        help="Path to output directory. Will be created by script.",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default="demetr_results.csv",
        help="Name for output CSV file.",
    )
    parser.add_argument(
        "--metrics",
        nargs="+",
        type=str,
        default=["COMET_ref", "COMET_qe", "BLASER_ref", "BLASER_qe", "SacreBLEU"],
        help=(
            "Metrics to use. Must be one or more of COMET_ref, COMET_qe, "
            "BLASER_ref, BLASER_qe, SacreBLEU. Defaults to all."
        ),
    )
    parser.add_argument(
        "--cats",
        nargs="+",
        type=int,
        required=False,
        # Typo fixed: "processsed" -> "processed".
        help=(
            "Specific DEMETR disfluency categories to be processed. "
            "By default all will be processed."
        ),
    )

    args = parser.parse_args()
    main(vars(args))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import evaluate | ||
import numpy as np | ||
from pandas import Series | ||
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline | ||
from sonar.models.blaser.loader import load_blaser_model | ||
|
||
|
||
class SacreBLEUScore:
    """Sentence-level SacreBLEU scoring via the HuggingFace `evaluate` library."""

    def __init__(self) -> None:
        # Load the metric once; `compute` is then called per sentence.
        self.bleu = evaluate.load("sacrebleu")

    def get_scores(self, references: Series, predictions: Series) -> list:
        """Return one SacreBLEU score per reference/prediction pair.

        `predictions` is looked up by the index of `references`, so the two
        Series must share an index.
        """
        # SacreBLEU doesn't seem to support batching that isn't document-level, so
        # each sentence must be run through separately
        return [
            self.bleu.compute(
                predictions=[predictions[idx]], references=[[ref_txt]]
            )["score"]
            for idx, ref_txt in references.items()
        ]
|
||
|
||
class BLASERRefScore:
    """Initialises and applies the BLASER 2.0 reference-based metric from the
    SONAR library.

    (Docstring fixed: this class loads ``blaser_2_0_ref``, the reference-based
    model, not the QE model.)
    """

    def __init__(self, ref_lang_code: str = "eng_Latn") -> None:
        self.blaser_ref = load_blaser_model("blaser_2_0_ref").eval()
        self.text_embedder = TextToEmbeddingModelPipeline(
            encoder="text_sonar_basic_encoder", tokenizer="text_sonar_basic_encoder"
        )
        # Code defining the target language
        # Defaults to English
        self.ref_lang_code = ref_lang_code

    def get_scores(
        self,
        references: Series,
        predictions: Series,
        sources: Series,
        source_lang_codes: Series,
    ) -> list:
        """Return one BLASER score per (source, reference, prediction) triple.

        All four Series must be aligned on the same index.

        NOTE(review): the loop below groups rows by source language, so the
        returned list is ordered language-by-language, not in the original
        row order of the inputs — confirm callers do not rely on positional
        alignment with the input Series.
        """
        langs = np.unique(source_lang_codes)

        # Store results for all languages so they can be returned together
        results = []

        # BLASER requires the source language, so at best we can batch by language as
        # source_lang must be a string
        for language in langs:
            mask = source_lang_codes == language
            sources_lang = np.array(sources[mask])
            refs_lang = np.array(references[mask])
            preds_lang = np.array(predictions[mask])

            src_embs = self.text_embedder.predict(sources_lang, source_lang=language)
            ref_embs = self.text_embedder.predict(
                refs_lang, source_lang=self.ref_lang_code
            )
            mt_embs = self.text_embedder.predict(
                preds_lang, source_lang=self.ref_lang_code
            )

            # Score one sentence at a time; [[i]] keeps the batch dimension.
            for i in range(len(src_embs)):
                result = self.blaser_ref(
                    src=src_embs[[i]], ref=ref_embs[[i]], mt=mt_embs[[i]]
                ).item()
                results.append(result)

        return results
|
||
|
||
class BLASERQEScore:
    """Initialises and applies the BLASER 2.0 QE (reference-free) metric from
    the SONAR library.

    (Docstring fixed: this class loads ``blaser_2_0_qe``, the quality-estimation
    model, not the reference-based one.)
    """

    def __init__(self, ref_lang_code: str = "eng_Latn") -> None:
        self.blaser_qe = load_blaser_model("blaser_2_0_qe").eval()
        self.text_embedder = TextToEmbeddingModelPipeline(
            encoder="text_sonar_basic_encoder", tokenizer="text_sonar_basic_encoder"
        )
        # Code defining the target language
        # Defaults to English
        self.ref_lang_code = ref_lang_code

    def get_scores(
        self, predictions: Series, sources: Series, source_lang_codes: Series
    ) -> list:
        """Return one BLASER QE score per (source, prediction) pair.

        All three Series must be aligned on the same index.

        NOTE(review): the loop below groups rows by source language, so the
        returned list is ordered language-by-language, not in the original
        row order of the inputs — confirm callers do not rely on positional
        alignment with the input Series.
        """
        langs = np.unique(source_lang_codes)

        # Store results for all languages so they can be returned together
        results = []

        # BLASER requires the source language, so at best we can batch by language as
        # source_lang must be a string
        for language in langs:
            mask = source_lang_codes == language
            sources_lang = np.array(sources[mask])
            preds_lang = np.array(predictions[mask])

            src_embs = self.text_embedder.predict(sources_lang, source_lang=language)
            mt_embs = self.text_embedder.predict(
                preds_lang, source_lang=self.ref_lang_code
            )

            # Score one sentence at a time; [[i]] keeps the batch dimension.
            for i in range(len(src_embs)):
                result = self.blaser_qe(src=src_embs[[i]], mt=mt_embs[[i]]).item()
                results.append(result)

        return results
|
||
|
||
class COMETRefScore:
    """Reference-based COMET scoring (wmt21-comet-mqm) via the `evaluate` library."""

    def __init__(self) -> None:
        # Model is downloaded/loaded once at construction time.
        self.comet = evaluate.load("comet", model="wmt21-comet-mqm")

    def get_scores(
        self, references: Series, predictions: Series, sources: Series
    ) -> list:
        """Return one COMET score per aligned (source, reference, prediction) row."""
        output = self.comet.compute(
            sources=sources,
            references=references,
            predictions=predictions,
        )
        return output["scores"]
|
||
|
||
class COMETQEScore:
    """COMET QE scoring (wmt21-comet-qe-mqm) via the `evaluate` library."""

    def __init__(self) -> None:
        # Model is downloaded/loaded once at construction time.
        self.comet = evaluate.load("comet", model="wmt21-comet-qe-mqm")

    def get_scores(
        self, references: Series, predictions: Series, sources: Series
    ) -> list:
        """Return one COMET QE score per aligned row.

        NOTE(review): `references` is accepted and forwarded even though this
        is a QE model — presumably ignored by the metric; confirm against the
        `evaluate` COMET implementation.
        """
        output = self.comet.compute(
            sources=sources,
            references=references,
            predictions=predictions,
        )
        return output["scores"]
Oops, something went wrong.