1 parent c4601d8 · commit 97b4b35 · Showing 10 changed files with 2,133 additions and 0 deletions.
@@ -0,0 +1 @@
from .eval_instruct import evaluation_only
Large diffs are not rendered by default.
@@ -0,0 +1,259 @@
from typing import Dict, List, Any, Optional, Generator
import json
import os
import re
import tempfile
from pathlib import Path
from tqdm import tqdm
import logging

from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from mbpp_plus.evaluation import evaluate_functional_correctness
from utils.utils import extract_generation_code, language_settings
from eval.task import BaseBenchmark


class MBPPPlusBenchmark(BaseBenchmark):
    """
    MBPPPlus benchmark for evaluating code generation capabilities across different languages.
    """

    def __init__(
        self,
        data_dir: str = "eval/chat_benchmarks/MBPPPlus/data",
        max_tokens: int = 1024,
        num_workers: int = 8,
        timeout: float = 3.0,
        debug: bool = False,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initialize MBPPPlus benchmark.

        Args:
            data_dir: Directory containing MBPPPlus datasets
            max_tokens: Maximum number of tokens for generation
            num_workers: Number of workers for parallel evaluation
            timeout: Timeout for code execution
            debug: If True, only evaluate first 2 examples
            logger: Optional logger instance
        """
        super().__init__(logger)
        self.data_dir = data_dir
        self.max_tokens = max_tokens
        self.num_workers = num_workers
        self.timeout = timeout
        self.debug = debug
        self.num_examples = 3
        self.start_idx = 0
        self.end_idx = 500

    def format_test_example(self, question: str, tests: List[str], code: Optional[str] = None) -> str:
        """Format a single test example."""
        prompt = ">>> Problem:\n{}\n>>> Test Cases:\n{}\n".format(question.strip(), "\n".join(tests))
        if code:
            code = code.replace("\r", "").replace("\t", " ")
            prompt += "\n>>> Code:\n```python\n{}\n```".format(code)
        return prompt

    def read_test_examples(self, data_path: str) -> Generator[Dict[str, str], None, None]:
        """
        Read and format test examples from data file.

        Args:
            data_path: Path to the data file

        Yields:
            Dictionary containing task_id and formatted prompt
        """
        try:
            with open(data_path, "r") as f:
                examples = [json.loads(x) for x in f]
            self.logger.info(f"Loaded {len(examples)} examples from {data_path}")

            examples_str = []
            for i in range(1, self.num_examples + 1):
                ex = examples[i]
                example_prompt = "- Example {}:\n{}".format(
                    i, self.format_test_example(ex["prompt"], ex["test_list"], ex["code"])
                )
                examples_str.append(example_prompt)

            eval_range = range(self.start_idx, min(self.end_idx, len(examples)))
            if self.debug:
                eval_range = list(eval_range)[:2]
                self.logger.info("Debug mode: using 2 examples")

            for i in eval_range:
                ex = examples[i]
                prompt = self.format_test_example(ex["prompt"], ex["test_list"])

                prompt_with_shots = """
Please refer the given examples and generate a python function for my problem.
Examples are listed as follows:
{}
Here is my problem:
{}
""".strip().format(
                    "\n\n".join(examples_str), prompt
                )

                yield {"task_id": ex["task_id"], "prompt": prompt_with_shots}

        except Exception as e:
            self.logger.error(f"Error reading examples: {str(e)}")
            raise

    def extract_code(self, completion: str) -> str:
        """Extract code block from model completion."""
        try:
            code_block = re.findall(r"```python\n(.*?)```", completion, re.DOTALL | re.IGNORECASE)[0]
            return code_block
        except Exception as e:
            self.logger.warning(f"Failed to extract code block, using full completion.\nError: {str(e)}")
            return completion

    def generate_responses(self, model: LM) -> Dict[str, Any]:
        """
        Generate code completions using the provided model.

        Args:
            model: Language model instance

        Returns:
            Dictionary containing generated responses and temporary directory,
            or None for non-primary ranks
        """
        try:
            temp_dir_obj = tempfile.TemporaryDirectory()
            temp_dir = temp_dir_obj.name

            problem_file = os.path.join(self.data_dir, "mbppplus.jsonl")
            examples = list(self.read_test_examples(problem_file))
            self.logger.info(f"Processing {len(examples)} examples")

            all_instances = []
            for idx, example in enumerate(examples):
                try:
                    inputs = model.apply_chat_template([{"role": "user", "content": example["prompt"]}])

                    all_instances.append(
                        Instance(
                            "generate_until",
                            example,
                            (
                                inputs,
                                {
                                    "max_gen_toks": self.max_tokens,
                                    "do_sample": False,
                                },
                            ),
                            idx,
                        )
                    )
                except Exception as e:
                    self.logger.error(f"Error preparing instance {idx}: {str(e)}")
                    continue

            self.logger.info("Generating responses for MBPPPlus...")
            outputs = self.compute(model, all_instances)

            # Return None early for non-primary ranks
            if model.rank != 0:
                return None

            generated_examples = []
            for example, output in zip(examples, outputs):
                try:
                    example_with_output = example.copy()
                    example_with_output["gpt_completion"] = output
                    example_with_output["generation"] = self.extract_code(output)
                    generated_examples.append(example_with_output)
                except Exception as e:
                    self.logger.error(f"Error processing output for {example['task_id']}: {str(e)}")
                    continue

            output_path = os.path.join(temp_dir, "generated_python.jsonl")
            with open(output_path, "w", encoding="utf-8") as fw:
                for ex in generated_examples:
                    fw.write(json.dumps(ex) + "\n")

            self.logger.info(f"Saved {len(generated_examples)} examples to {output_path}")

            return {
                "temp_dir_obj": temp_dir_obj,
                "num_examples": len(generated_examples),
                "total_examples": len(examples),
            }

        except Exception as e:
            self.logger.error(f"Error in generate_responses: {str(e)}")
            raise

    def evaluate_responses(self, results: Dict[str, Any]) -> Dict[str, float]:
        """
        Evaluate the generated code completions.

        Args:
            results: Dictionary containing generation results

        Returns:
            Dictionary containing evaluation metrics
        """
        # Handle None result from non-primary ranks
        if results is None:
            return None

        temp_dir_obj = results["temp_dir_obj"]
        temp_dir = temp_dir_obj.name

        evaluation_results = {}

        problem_file = os.path.join(self.data_dir, "mbppplus.jsonl")
        temp_file_path = os.path.join(temp_dir, "generated_python.jsonl")

        if not os.path.exists(temp_file_path):
            self.logger.warning(f"Generated file not found: {temp_file_path}")

        result = evaluate_functional_correctness(
            input_file=temp_file_path,
            tmp_dir=temp_dir,
            n_workers=self.num_workers,
            timeout=self.timeout,
            problem_file=problem_file,
            language="python",
        )

        for metric, value in result.items():
            evaluation_results[metric] = value

        self.logger.info("Completed evaluation")

        temp_dir_obj.cleanup()
        return evaluation_results

    def run_benchmark(self, model: LM) -> Dict[str, float]:
        """
        Run the complete benchmark evaluation pipeline.

        Args:
            model: Language model instance

        Returns:
            Dictionary containing evaluation results, or None for non-primary ranks
        """
        self.logger.info("Running MBPPPlus benchmark")
        try:
            generation_results = self.generate_responses(model)

            # If not primary rank, return None early
            if generation_results is None:
                return None

            evaluation_results = self.evaluate_responses(generation_results)
            return evaluation_results
        except Exception as e:
            self.logger.error(f"Error running benchmark: {str(e)}")
            return {"error": str(e)}
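The benchmark class is only defined in this commit; no driver is shown in the rendered diff. The minimal sketch below illustrates how it might be exercised. It is an assumption-laden sketch, not part of the commit: it assumes the module is importable as eval.chat_benchmarks.MBPPPlus.eval_instruct (matching the default data_dir), that lm-eval's HFLM wrapper satisfies the LM interface the class relies on (apply_chat_template, rank, generate_until-style generation), and that BaseBenchmark.compute dispatches the prepared Instance objects; the model name is an arbitrary example.

# Hypothetical driver (not part of this commit); paths and model name are illustrative assumptions.
import logging

from lm_eval.models.huggingface import HFLM

from eval.chat_benchmarks.MBPPPlus.eval_instruct import MBPPPlusBenchmark  # assumed module path

logging.basicConfig(level=logging.INFO)

# Any lm-eval model wrapper exposing apply_chat_template, rank, and generate_until should work here.
model = HFLM(pretrained="deepseek-ai/deepseek-coder-1.3b-instruct", batch_size=8)

benchmark = MBPPPlusBenchmark(
    debug=True,  # evaluate only the first 2 problems as a smoke test
    logger=logging.getLogger("mbppplus"),
)

metrics = benchmark.run_benchmark(model)  # returns None on non-primary ranks
if metrics is not None:
    print(metrics)  # pass@k-style metrics from evaluate_functional_correctness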
Empty file.
@@ -0,0 +1,41 @@
from typing import Iterable, Dict
import gzip
import json
import os


def stream_jsonl(filename: str) -> Iterable[Dict]:
    """
    Parses each jsonl line and yields it as a dictionary
    """
    if filename.endswith(".gz"):
        with open(filename, "rb") as gzfp:
            with gzip.open(gzfp, "rt") as fp:
                for line in fp:
                    if any(not x.isspace() for x in line):
                        yield json.loads(line)
    else:
        with open(filename, "r", encoding="utf-8") as fp:
            for line in fp:
                if any(not x.isspace() for x in line):
                    yield json.loads(line)


def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):
    """
    Writes an iterable of dictionaries to jsonl
    """
    if append:
        mode = "ab"
    else:
        mode = "wb"
    filename = os.path.expanduser(filename)
    if filename.endswith(".gz"):
        with open(filename, mode) as fp:
            with gzip.GzipFile(fileobj=fp, mode="wb") as gzfp:
                for x in data:
                    gzfp.write((json.dumps(x) + "\n").encode("utf-8"))
    else:
        with open(filename, mode) as fp:
            for x in data:
                fp.write((json.dumps(x) + "\n").encode("utf-8"))
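A brief usage note (not part of the commit): both helpers key their gzip handling off the filename suffix, so a round trip with a .gz file looks like the sketch below. The file name and records are made-up placeholders.

# Made-up sample records; the task_id values are only placeholders.
samples = [
    {"task_id": 2, "generation": "def add(a, b):\n    return a + b"},
    {"task_id": 3, "generation": "def square(n):\n    return n * n"},
]

write_jsonl("generations.jsonl.gz", samples)                  # gzip-compressed because of the .gz suffix
write_jsonl("generations.jsonl.gz", samples, append=True)     # "ab" mode appends another gzip member

for record in stream_jsonl("generations.jsonl.gz"):           # gzip.open(..., "rt") decompresses transparently
    print(record["task_id"])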