Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Chinese QA feature #83

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"flare_fiqasa": flare.FIQASA,
"flare_ner": flare.NER,
"flare_finqa": flare.FinQA,
"flare_auditqa_zh": flare.AuditQA,
"flare_convfinqa": flare.ConvFinQA,
"flare_headlines": flare.Headlines,
"flare_finer_ord": flare.FinerOrd,
Expand Down
133 changes: 133 additions & 0 deletions src/tasks/flare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import re
from factscore_package.factscorer import FactScorer
import os
import jieba
from rouge_chinese import Rouge
#from comet import download_model, load_from_checkpoint

_CITATION = """
Expand Down Expand Up @@ -698,6 +700,130 @@ def aggregation(self):
}


class QA_zh(Task):
    """Free-form question-answering task for Chinese text.

    The model continuation is requested with ``rf.greedy_until`` and compared
    with ``doc["answer"]`` using ROUGE-1/2/L F-scores (jieba word segmentation
    fed to ``rouge_chinese``) and BERTScore F1 (``bert-base-chinese``).
    """

    VERSION = 1
    DATASET_NAME = None
    EVAL_LAST_TURN = True

    # Optional path to a custom jieba dictionary used for Chinese word
    # segmentation. Leave as None to use jieba's built-in dictionary;
    # subclasses may override it with a local vocabulary file.
    USER_DICT_PATH = None
    # Dictionary path already handed to jieba, so it is only loaded once.
    _loaded_user_dict = None

    def reformulate_turn_req(self, req, turn_request, turn):
        """Return ``req`` unchanged; no per-turn rewriting is done here."""
        return req

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        return self.dataset["test"]

    def doc_to_text(self, doc):
        """Return the prompt text, read from ``doc["query"]``."""
        # TODO: Format the query prompt portion of the document example.
        return doc["query"]

    def doc_to_target(self, doc):
        """Return the reference answer, read from ``doc["answer"]``."""
        return doc["answer"]

    def process_results(self, doc, results):
        """Pair the reference answer with the first result for every metric.

        Each metric receives a ``(gold, prediction)`` tuple, where ``gold`` is
        ``doc["answer"]`` and ``prediction`` is ``results[0]`` (presumably the
        model generation -- confirm against the harness).
        """
        pair = (doc["answer"], results[0])
        return {
            "rouge1": pair,
            "rouge2": pair,
            "rougeL": pair,
            "bert_score_f1": pair,
        }

    def higher_is_better(self):
        return {
            "rouge1": True,
            "rouge2": True,
            "rougeL": True,
            "bert_score_f1": True,
        }

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        cont_request = rf.greedy_until(ctx, {"until": None})
        return cont_request

    def _segment(self, text):
        """Return ``text`` as jieba tokens joined by single spaces."""
        return " ".join(jieba.cut(text))

    def rouge_score_zh(self, items):
        """Compute averaged ROUGE scores over ``(gold, prediction)`` pairs.

        :param items:
            Iterable of ``(gold, prediction)`` string tuples, in the order
            produced by ``process_results``.
        :return:
            Dict keyed by ``"rouge-1"``, ``"rouge-2"`` and ``"rouge-l"``, each
            mapping to a dict with ``"f"``, ``"p"`` and ``"r"`` entries. All
            values are 0.0 when no usable pair remains after filtering.
        """
        user_dict = self.USER_DICT_PATH
        if user_dict and QA_zh._loaded_user_dict != user_dict:
            jieba.load_userdict(user_dict)
            QA_zh._loaded_user_dict = user_dict

        hyps = []
        refs = []
        for gold, pred in items:
            # items are (gold, prediction): the gold answer is the ROUGE
            # reference and the model output is the hypothesis.
            ref = self._segment(gold)
            hyp = self._segment(pred)
            # Drop pairs where either side is empty or whitespace-only.
            if not hyp or not ref:
                continue
            if self.is_whitespace_string(hyp) or self.is_whitespace_string(ref):
                continue
            hyps.append(hyp)
            refs.append(ref)

        if not hyps:
            # Averaging over zero pairs is undefined; report zero scores
            # instead of raising.
            return {
                name: {"f": 0.0, "p": 0.0, "r": 0.0}
                for name in ("rouge-1", "rouge-2", "rouge-l")
            }

        rouge = Rouge()
        return rouge.get_scores(hyps, refs, avg=True, ignore_empty=True)

    def rouge1(self, items):
        """Return the averaged ROUGE-1 F-score for ``items``."""
        results = self.rouge_score_zh(items)
        return results["rouge-1"]['f']

    def rouge2(self, items):
        """Return the averaged ROUGE-2 F-score for ``items``."""
        results = self.rouge_score_zh(items)
        return results["rouge-2"]['f']

    def rougeL(self, items):
        """Return the averaged ROUGE-L F-score for ``items``."""
        results = self.rouge_score_zh(items)
        return results["rouge-l"]['f']

    def is_whitespace_string(self, s):
        """Return True when ``s`` is non-empty and made only of whitespace."""
        return s.isspace()

    def bert_score(self, items):
        """Compute BERTScore for ``(gold, prediction)`` pairs, cached per instance.

        The first call stores the result in ``self._cache_bertscore``; later
        calls return that stored result without looking at ``items`` again.
        """
        if getattr(self, "_cache_bertscore", None) is None:
            golds, preds = zip(*items)
            bertscore = evaluate.load("evaluate-metric/bertscore")
            self._cache_bertscore = bertscore.compute(
                predictions=preds,
                references=golds,
                model_type="bert-base-chinese",
            )
        return self._cache_bertscore

    def bert_score_f1(self, items):
        """Return the mean of the per-item BERTScore F1 values."""
        res = self.bert_score(items)
        return sum(res["f1"]) / len(res["f1"])

    def aggregation(self):
        return {
            "rouge1": self.rouge1,
            "rouge2": self.rouge2,
            "rougeL": self.rougeL,
            "bert_score_f1": self.bert_score_f1,
        }


class FPB(Classification):
DATASET_PATH = "chancefocus/flare-fpb"

Expand Down Expand Up @@ -787,6 +913,13 @@ class FinQA(QA):
DATASET_PATH = "chancefocus/flare-finqa"


class AuditQA(QA_zh):
    # Chinese QA task that reuses the QA_zh prompts and metrics unchanged.
    # DATASET_PATH below is a placeholder: point it at the actual location
    # of the Chinese dataset before running this task.
    DATASET_PATH = "/path/to/dataset"


class StockMovement(Classification):
DATASET_NAME = None
CALCULATE_MCC = True
Expand Down