import os

from ...llms.base_llm import BaseLLM
from ..base_evaluator import BaseEvaluator
from .dataset import ExampleQADataset

PROMPT_TEMPLATE = """You will be asked a series of questions. For each question, you need to provide a short answer. Do not provide any additional information.
Q: The capital of Germany is?
A: Berlin
Q: What is the largest planet in our solar system?
A: Jupiter
Q: How many continents are there on Earth? (answer in numeral)
A: 7
Q: What is 17+25? (answer in numeral)
A: 42
Q: {question}
A: """
class ExampleQAEvaluator(BaseEvaluator):
def __init__(
self, model: BaseLLM, num_batches: int = 1, output_dir: str = "./output"
):
super().__init__(model, num_batches, output_dir)
def set_generation_configs(self) -> None:
new_configs = {"max_new_tokens": 16, "do_sample": False}
self.model.update_generation_configs(new_configs)
def load_batched_dataset(self) -> list[list[dict]]:
dir = os.path.dirname(__file__)
filename = f"dataset_exampleqa.jsonl"
path = os.path.join(dir, filename)
dataset = ExampleQADataset(path)
batches = dataset.to_batched(self.num_batches)
return batches
def scoring(self, data_point: dict) -> dict:
query = PROMPT_TEMPLATE.format(question=data_point["q"])
response = self.model.safe_request(query)
answer = response.strip().split("\n")[0].strip() # Get the first line
return {
"metrics": {
"correct": answer == data_point["a"],
},
"log": {
"answer": answer,
"response": response,
},
"valid": answer != "",
}
def compute_overall(self, results: list[dict]) -> dict:
return {
"accuracy": sum([result["metrics"]["correct"] for result in results])
/ len(results),
"num": len(results),
}