From aa471914f8e0107d866fc534fc5396602788f10e Mon Sep 17 00:00:00 2001
From: SumanthRH
Date: Thu, 23 Jan 2025 22:39:10 +0000
Subject: [PATCH 01/47] refactoring init

Signed-off-by: SumanthRH
---
 skythought/tools/inference_and_check.py | 1 -
 skythought/tools/response_rewrite.py | 2 +-
 skythought/tools/tasks/__init__.py | 0
 skythought/tools/tasks/aime/aime.yaml | 0
 skythought/tools/tasks/aime/aime_handler.py | 0
 skythought/tools/tasks/apps/apps.yaml | 0
 skythought/tools/tasks/apps/apps_handler.py | 0
 .../apps/testing_util.py => tasks/apps/apps_util.py} | 0
 skythought/tools/tasks/livecodebench/livecodebench.yaml | 0
 .../tools/tasks/livecodebench/livecodebench_handler.py | 0
 .../livecodebench/livecodebench_util.py} | 0
 skythought/tools/tasks/math500/math500.yaml | 0
 skythought/tools/tasks/math500/math500_handler.py | 0
 skythought/tools/{util => tasks}/taco/pyext2.py | 0
 skythought/tools/tasks/taco/taco.yaml | 0
 skythought/tools/tasks/taco/taco_handler.py | 0
 .../taco/testing_util.py => tasks/taco/taco_util.py} | 0
 .../util/{math/testing_util.py => math_parsing_util.py} | 0
 skythought/tools/util/task_handlers.py | 8 ++++----
 19 files changed, 5 insertions(+), 6 deletions(-)
 create mode 100644 skythought/tools/tasks/__init__.py
 create mode 100644 skythought/tools/tasks/aime/aime.yaml
 create mode 100644 skythought/tools/tasks/aime/aime_handler.py
 create mode 100644 skythought/tools/tasks/apps/apps.yaml
 create mode 100644 skythought/tools/tasks/apps/apps_handler.py
 rename skythought/tools/{util/apps/testing_util.py => tasks/apps/apps_util.py} (100%)
 create mode 100644 skythought/tools/tasks/livecodebench/livecodebench.yaml
 create mode 100644 skythought/tools/tasks/livecodebench/livecodebench_handler.py
 rename skythought/tools/{util/livecodebench/testing_util.py => tasks/livecodebench/livecodebench_util.py} (100%)
 create mode 100644 skythought/tools/tasks/math500/math500.yaml
 create mode 100644 skythought/tools/tasks/math500/math500_handler.py
 rename skythought/tools/{util => tasks}/taco/pyext2.py (100%)
 create mode 100644 skythought/tools/tasks/taco/taco.yaml
 create mode 100644 skythought/tools/tasks/taco/taco_handler.py
 rename skythought/tools/{util/taco/testing_util.py => tasks/taco/taco_util.py} (100%)
 rename skythought/tools/util/{math/testing_util.py => math_parsing_util.py} (100%)

diff --git a/skythought/tools/inference_and_check.py b/skythought/tools/inference_and_check.py
index c716ada..587e590 100644
--- a/skythought/tools/inference_and_check.py
+++ b/skythought/tools/inference_and_check.py
@@ -48,7 +48,6 @@ def perform_inference_and_check(handler: TaskHandler, temperatures, max_tokens,
         filter_difficulty=args.filter_difficulty, args=args)
     remaining_data = handler.process_remaining_data(train_data, results)
     conversations = handler.make_conversations(remaining_data, system_prompt, args.model)
-

     for temp in temperatures:
         if args.model.startswith("openai"):
diff --git a/skythought/tools/response_rewrite.py b/skythought/tools/response_rewrite.py
index 1d9be62..6dff3a6 100644
--- a/skythought/tools/response_rewrite.py
+++ b/skythought/tools/response_rewrite.py
@@ -3,7 +3,7 @@
 import os
 import random
 from tqdm import tqdm
-from util.math.testing_util import strip_answer_string
+from skythought.tools.util.math_parsing_util import strip_answer_string
 from util.model_utils import *
 from vllm import LLM, SamplingParams

diff --git a/skythought/tools/tasks/__init__.py b/skythought/tools/tasks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/skythought/tools/tasks/aime/aime.yaml
b/skythought/tools/tasks/aime/aime.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/skythought/tools/tasks/aime/aime_handler.py b/skythought/tools/tasks/aime/aime_handler.py
new file mode 100644
index 0000000..e69de29
diff --git a/skythought/tools/tasks/apps/apps.yaml b/skythought/tools/tasks/apps/apps.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/skythought/tools/tasks/apps/apps_handler.py b/skythought/tools/tasks/apps/apps_handler.py
new file mode 100644
index 0000000..e69de29
diff --git a/skythought/tools/util/apps/testing_util.py b/skythought/tools/tasks/apps/apps_util.py
similarity index 100%
rename from skythought/tools/util/apps/testing_util.py
rename to skythought/tools/tasks/apps/apps_util.py
diff --git a/skythought/tools/tasks/livecodebench/livecodebench.yaml b/skythought/tools/tasks/livecodebench/livecodebench.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/skythought/tools/tasks/livecodebench/livecodebench_handler.py b/skythought/tools/tasks/livecodebench/livecodebench_handler.py
new file mode 100644
index 0000000..e69de29
diff --git a/skythought/tools/util/livecodebench/testing_util.py b/skythought/tools/tasks/livecodebench/livecodebench_util.py
similarity index 100%
rename from skythought/tools/util/livecodebench/testing_util.py
rename to skythought/tools/tasks/livecodebench/livecodebench_util.py
diff --git a/skythought/tools/tasks/math500/math500.yaml b/skythought/tools/tasks/math500/math500.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/skythought/tools/tasks/math500/math500_handler.py b/skythought/tools/tasks/math500/math500_handler.py
new file mode 100644
index 0000000..e69de29
diff --git a/skythought/tools/util/taco/pyext2.py b/skythought/tools/tasks/taco/pyext2.py
similarity index 100%
rename from skythought/tools/util/taco/pyext2.py
rename to skythought/tools/tasks/taco/pyext2.py
diff --git a/skythought/tools/tasks/taco/taco.yaml b/skythought/tools/tasks/taco/taco.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/skythought/tools/tasks/taco/taco_handler.py b/skythought/tools/tasks/taco/taco_handler.py
new file mode 100644
index 0000000..e69de29
diff --git a/skythought/tools/util/taco/testing_util.py b/skythought/tools/tasks/taco/taco_util.py
similarity index 100%
rename from skythought/tools/util/taco/testing_util.py
rename to skythought/tools/tasks/taco/taco_util.py
diff --git a/skythought/tools/util/math/testing_util.py b/skythought/tools/util/math_parsing_util.py
similarity index 100%
rename from skythought/tools/util/math/testing_util.py
rename to skythought/tools/util/math_parsing_util.py
diff --git a/skythought/tools/util/task_handlers.py b/skythought/tools/util/task_handlers.py
index 0974ccf..4f33eb5 100644
--- a/skythought/tools/util/task_handlers.py
+++ b/skythought/tools/util/task_handlers.py
@@ -8,10 +8,10 @@
 from datasets import load_dataset
 from typing import Dict, Any
 from multiprocessing import Manager
-from .apps.testing_util import run_test as apps_run_test
-from .taco.testing_util import run_test as taco_run_test
-from .math.testing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer
-from .livecodebench.testing_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases
+from tasks.apps.apps_util import run_test as apps_run_test
+from tasks.taco.taco_util import run_test as taco_run_test
+from .math_parsing_util import strip_answer_string, get_multiple_choice_answer,
extract_answer, math_equal, mmlu_pro_extract_answer +from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases from .common import TimeoutException, timeout from util.model_utils import * From ee8d0f4b2a101d6dd4d9935ea515af7ea0763068 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 00:46:17 +0000 Subject: [PATCH 02/47] add a bunch of stuff Signed-off-by: SumanthRH --- skythought/tools/.githooks/pre-commit | 5 + skythought/tools/.pre-commit-config.yaml | 12 + skythought/tools/combine_data.py | 23 +- skythought/tools/convert_format.py | 51 +- skythought/tools/convert_to_data.py | 26 +- skythought/tools/eval.py | 85 +- skythought/tools/format.sh | 14 + skythought/tools/inference_and_check.py | 373 +++++++-- skythought/tools/label_math_difficulty.py | 47 +- skythought/tools/pyproject.toml | 3 + skythought/tools/response_rewrite.py | 774 ++++++++++-------- skythought/tools/tasks/aime/aime_handler.py | 40 + .../math500.yaml => amc23/amc23.yaml} | 0 skythought/tools/tasks/amc23/amc23_handler.py | 24 + skythought/tools/tasks/apps/apps_handler.py | 128 +++ skythought/tools/tasks/apps/apps_util.py | 293 ++++--- skythought/tools/tasks/common.py | 210 +++++ .../gpqa_diamond/gpqa_diamond_handler.py | 90 ++ skythought/tools/tasks/gsm8k/gsm8k_handler.py | 100 +++ .../livecodebench/livecodebench_handler.py | 137 ++++ .../tasks/livecodebench/livecodebench_util.py | 93 ++- .../math500_handler.py => math/math500.yaml} | 0 .../tools/tasks/math/math500_handler.py | 10 + skythought/tools/tasks/math/math_handler.py | 69 ++ skythought/tools/tasks/mmlu/mmlu_handler.py | 113 +++ .../tools/tasks/numina/numina_handler.py | 83 ++ skythought/tools/tasks/taco/pyext2.py | 486 +++++++---- skythought/tools/tasks/taco/taco_handler.py | 124 +++ skythought/tools/tasks/taco/taco_util.py | 323 +++++--- skythought/tools/upload_hub.py | 5 +- skythought/tools/util/common.py | 16 +- skythought/tools/util/math_parsing_util.py | 37 +- skythought/tools/util/model_utils.py | 4 +- skythought/tools/util/prompts.py | 8 +- skythought/tools/util/task_handlers.py | 579 ++++++++----- 35 files changed, 3255 insertions(+), 1130 deletions(-) create mode 100644 skythought/tools/.githooks/pre-commit create mode 100644 skythought/tools/.pre-commit-config.yaml create mode 100644 skythought/tools/format.sh create mode 100644 skythought/tools/pyproject.toml rename skythought/tools/tasks/{math500/math500.yaml => amc23/amc23.yaml} (100%) create mode 100644 skythought/tools/tasks/amc23/amc23_handler.py create mode 100644 skythought/tools/tasks/common.py create mode 100644 skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py create mode 100644 skythought/tools/tasks/gsm8k/gsm8k_handler.py rename skythought/tools/tasks/{math500/math500_handler.py => math/math500.yaml} (100%) create mode 100644 skythought/tools/tasks/math/math500_handler.py create mode 100644 skythought/tools/tasks/math/math_handler.py create mode 100644 skythought/tools/tasks/mmlu/mmlu_handler.py create mode 100644 skythought/tools/tasks/numina/numina_handler.py diff --git a/skythought/tools/.githooks/pre-commit b/skythought/tools/.githooks/pre-commit new file mode 100644 index 0000000..094125f --- /dev/null +++ b/skythought/tools/.githooks/pre-commit @@ -0,0 +1,5 @@ +# Only run pre-commit if changes are in tools/ +if git diff --cached --name-only | grep "^tools/"; then + cd skythought/tools/ + pre-commit run --files $(git diff --cached --name-only | grep "^tools/") +fi \ No 
newline at end of file diff --git a/skythought/tools/.pre-commit-config.yaml b/skythought/tools/.pre-commit-config.yaml new file mode 100644 index 0000000..003a141 --- /dev/null +++ b/skythought/tools/.pre-commit-config.yaml @@ -0,0 +1,12 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.9 + hooks: + - id: ruff + args: [ --fix, --exit-non-zero-on-fix ] + + # Black needs to be ran after ruff with --fix + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black \ No newline at end of file diff --git a/skythought/tools/combine_data.py b/skythought/tools/combine_data.py index 492788b..2ab86e9 100644 --- a/skythought/tools/combine_data.py +++ b/skythought/tools/combine_data.py @@ -1,5 +1,6 @@ import json import random + from util.prompts import system_prompt still2_jsonl_file = "../../data/public_long_form_thought_data_5k.jsonl" @@ -25,14 +26,11 @@ # Create the conversation format conversations = [ {"from": "user", "value": question}, - {"from": "assistant", "value": combined_text} + {"from": "assistant", "value": combined_text}, ] - + # Prepare the final structure - cur_data = { - "system": system_prompt, - "conversations": conversations - } + cur_data = {"system": system_prompt, "conversations": conversations} all_data.append(cur_data) else: code_num += 1 @@ -43,14 +41,19 @@ # print(code_data[0]) all_data.extend(code_data) -print(f"First item slice before shuffle: {all_data[0]['conversations'][-1]['value'][-50:-1]}") +print( + f"First item slice before shuffle: {all_data[0]['conversations'][-1]['value'][-50:-1]}" +) random.shuffle(all_data) -print(f"First item slice after shuffle: {all_data[0]['conversations'][-1]['value'][-50:-1]}") +print( + f"First item slice after shuffle: {all_data[0]['conversations'][-1]['value'][-50:-1]}" +) print(len(all_data)) # Save the converted data to the output file with open(output_file, "w") as f: json.dump(all_data, f, indent=4) -print(f"Conversion completed. The data has been saved to {output_file} with {len(all_data)} data.") - +print( + f"Conversion completed. The data has been saved to {output_file} with {len(all_data)} data." 
+) diff --git a/skythought/tools/convert_format.py b/skythought/tools/convert_format.py index fe02477..ead9d3a 100644 --- a/skythought/tools/convert_format.py +++ b/skythought/tools/convert_format.py @@ -1,23 +1,28 @@ -import json import argparse -from tqdm import tqdm +import json import multiprocessing as mp -import openai -from itertools import cycle -import time import os +import time +from itertools import cycle + +import openai +from tqdm import tqdm + from util.prompts import convert_prompt, convert_prompt_example global args + + # Function to set the OpenAI API key def set_openai_key(api_key): openai.api_key = api_key + # GPT API processing function with retry logic def process_content(content, api_key): # Set the OpenAI key for this request set_openai_key(api_key) - + # GPT prompt prompt = convert_prompt.format(example=convert_prompt_example, content=content) @@ -28,44 +33,54 @@ def process_content(content, api_key): response = openai.chat.completions.create( model="gpt-4o-mini", messages=[ - {"role": "system", "content": "You are a solution format convertor."}, - {"role": "user", "content": prompt} + { + "role": "system", + "content": "You are a solution format convertor.", + }, + {"role": "user", "content": prompt}, ], max_tokens=16384, - temperature=0.7 + temperature=0.7, ) return response.choices[0].message.content except openai.RateLimitError: retries -= 1 if retries == 0: return "Error: Rate limit reached and retries exhausted." - print(f"Sleep for 5 seconds for API limit.") + print("Sleep for 5 seconds for API limit.") time.sleep(5) except Exception as e: return f"Error processing content: {e}" + # Function for multiprocessing def process_entry(entry, api_key_cycle): key, values = entry content = values["responses"]["0.7"]["content"] - + # Get the next API key from the cycle api_key = next(api_key_cycle) - + processed = process_content(content, api_key) values["responses"]["0.7"]["processed_content"] = processed - + return key, values + # Wrapper function for multiprocessing def process_entry_wrapper(args): return process_entry(*args) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Process content and save results.") - parser.add_argument("--input_dir", type=str, help="Input directory containing JSON files.") - parser.add_argument("--keys", type=str, help="File containing OpenAI API keys (one per line).") - + parser.add_argument( + "--input_dir", type=str, help="Input directory containing JSON files." + ) + parser.add_argument( + "--keys", type=str, help="File containing OpenAI API keys (one per line)." 
+ ) + global args args = parser.parse_args() @@ -90,7 +105,9 @@ def process_entry_wrapper(args): results = [] with mp.Pool(os.cpu_count()) as pool: tasks = [(entry, api_key_cycle) for entry in data.items()] - for result in tqdm(pool.imap(process_entry_wrapper, tasks), total=len(data)): + for result in tqdm( + pool.imap(process_entry_wrapper, tasks), total=len(data) + ): results.append(result) # Aggregate and write results in the main process diff --git a/skythought/tools/convert_to_data.py b/skythought/tools/convert_to_data.py index 91670b2..152f9a8 100644 --- a/skythought/tools/convert_to_data.py +++ b/skythought/tools/convert_to_data.py @@ -1,11 +1,15 @@ -import os -import json import argparse +import json +import os + from util.prompts import system_prompt + def main(): parser = argparse.ArgumentParser(description="Convert JSON data for processing.") - parser.add_argument("--input_dir", type=str, help="Directory containing input JSON files.") + parser.add_argument( + "--input_dir", type=str, help="Directory containing input JSON files." + ) parser.add_argument("--output", type=str, help="Output JSON file.") args = parser.parse_args() @@ -24,19 +28,24 @@ def main(): for cur_temp, cur_temp_response in response_data.items(): # Only support 0.7 for this version - assert cur_temp == "0.7", "Only support a single temperature=0.7 now." + assert ( + cur_temp == "0.7" + ), "Only support a single temperature=0.7 now." # Accept this data if cur_temp_response["correctness"]: # Create the conversation format conversations = [ {"from": "user", "value": prompt}, - {"from": "assistant", "value": cur_temp_response["processed_content"]} + { + "from": "assistant", + "value": cur_temp_response["processed_content"], + }, ] # Prepare the final structure cur_data = { "system": system_prompt, - "conversations": conversations + "conversations": conversations, } all_data.append(cur_data) @@ -44,7 +53,10 @@ def main(): with open(args.output, "w") as f: json.dump(all_data, f, indent=4) - print(f"Conversion completed. The data has been saved to {args.output} with {len(all_data)} data.") + print( + f"Conversion completed. The data has been saved to {args.output} with {len(all_data)} data." + ) + if __name__ == "__main__": main() diff --git a/skythought/tools/eval.py b/skythought/tools/eval.py index 4a729cc..b577553 100644 --- a/skythought/tools/eval.py +++ b/skythought/tools/eval.py @@ -1,32 +1,53 @@ import argparse -import subprocess -import os import json +import subprocess # Define eval to split mapping eval_to_split = { - "MATH500": "test", - "AIME": "train", - "GPQADiamond": "train", - "MMLU": "test", - "MMLUPro": "test", - "LiveCodeBench": "test", - "GSM8K": "test", - "ARC-C": "test", - "AMC23": "train", + "MATH500": "test", + "AIME": "train", + "GPQADiamond": "train", + "MMLU": "test", + "MMLUPro": "test", + "LiveCodeBench": "test", + "GSM8K": "test", + "ARC-C": "test", + "AMC23": "train", } + def parse_arguments(): - parser = argparse.ArgumentParser(description="Process model path, prompt format, and evals to run.") + parser = argparse.ArgumentParser( + description="Process model path, prompt format, and evals to run." 
+ ) parser.add_argument("--model", required=True, type=str, help="Path to the model.") - parser.add_argument("--evals", required=True, type=str, help="Comma-separated list of evals to run (no spaces).") + parser.add_argument( + "--evals", + required=True, + type=str, + help="Comma-separated list of evals to run (no spaces).", + ) parser.add_argument("--tp", type=int, default=8, help="Tensor Parallelism Degree") - parser.add_argument("--filter-difficulty", action="store_true", help="Filter difficulty.") + parser.add_argument( + "--filter-difficulty", action="store_true", help="Filter difficulty." + ) parser.add_argument("--source", type=str, help="Source for the dataset.") - parser.add_argument("--output_file", required=True, type=str, help="Output file to write results to.") - parser.add_argument("--temperatures", type=float, nargs="+", default=[0], help="Temperature for sampling.") + parser.add_argument( + "--output_file", + required=True, + type=str, + help="Output file to write results to.", + ) + parser.add_argument( + "--temperatures", + type=float, + nargs="+", + default=[0], + help="Temperature for sampling.", + ) return parser.parse_args() + def extract_accuracy_from_output(output): # Iterate through all lines from the end to the beginning lines = output.splitlines()[::-1] @@ -37,9 +58,10 @@ def extract_accuracy_from_output(output): if "acc" in data: return data["acc"] except json.JSONDecodeError: - continue + continue return None + def write_logs_to_file(logs, output_file): try: with open(output_file, "w") as file: @@ -48,6 +70,7 @@ def write_logs_to_file(logs, output_file): except IOError as e: print(f"Failed to write logs to file {output_file}: {e}") + def main(): args = parse_arguments() @@ -60,22 +83,27 @@ def main(): script_path = "inference_and_check.py" - # Hold all logs + # Hold all logs all_logs = "" results = {} - + # Run the Python command for each eval and collect logs for eval_name in evals: command = [ - "python", script_path, - "--model", model_path, - "--dataset", eval_name, - "--split", eval_to_split[eval_name], - "--tp", str(tp), - "--temperatures" + "python", + script_path, + "--model", + model_path, + "--dataset", + eval_name, + "--split", + eval_to_split[eval_name], + "--tp", + str(tp), + "--temperatures", ] command.extend(temperatures) # Add temperatures as separate arguments - + if args.filter_difficulty: assert args.source != "", "No source passed for filtering difficulty." 
command.append("--filter-difficulty") @@ -84,7 +112,9 @@ def main(): print(f"Running eval {eval_name} with command {command}") all_logs += f"\nRunning eval: {eval_name} with command {command}\n" try: - with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) as proc: + with subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True + ) as proc: output_lines = [] for line in proc.stdout: print(line, end="") # Stream output to the console @@ -110,5 +140,6 @@ def main(): print("Results:") print(results) + if __name__ == "__main__": main() diff --git a/skythought/tools/format.sh b/skythought/tools/format.sh new file mode 100644 index 0000000..16cbd4e --- /dev/null +++ b/skythought/tools/format.sh @@ -0,0 +1,14 @@ +set -e + +# Get tools directory path relative to git root +TOOLS_DIR=$(git rev-parse --show-toplevel)/skythought/tools + +if command -v uv >/dev/null 2>&1; then + uv pip install -q pre-commit +else + pip install -q pre-commit +fi + +git config --local core.hooksPath "$TOOLS_DIR/.githooks" +# pre-commit run --all-files always runs from the root directory. we run this only on tools/ for now. +pre-commit run --files $TOOLS_DIR/* \ No newline at end of file diff --git a/skythought/tools/inference_and_check.py b/skythought/tools/inference_and_check.py index 587e590..4deb736 100644 --- a/skythought/tools/inference_and_check.py +++ b/skythought/tools/inference_and_check.py @@ -1,35 +1,40 @@ -import json import argparse -import re -from concurrent.futures import ProcessPoolExecutor, as_completed -from vllm import LLM, SamplingParams -from tqdm import tqdm -from util.task_handlers import * -from util.model_utils import * -from openai import OpenAI import concurrent.futures +import json +import os +from concurrent.futures import ProcessPoolExecutor, as_completed from functools import partial +import numpy as np +from openai import OpenAI +from tqdm import tqdm +from vllm import LLM, SamplingParams + +from util.model_utils import MODEL_TO_NAME, SYSTEM_PROMPT +from util.task_handlers import TASK_HANDLERS, NUMINATaskHandler, TaskHandler + + class NumpyEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.ndarray): return obj.tolist() return super().default(obj) + def fetch_response_openai(llm, model_name, max_tokens, temp, prompt): model_name = model_name.replace("openai/", "") if "o1" in model_name: # O1 doesn't support system prompt - # NOTE: might want to implement this inside handler instead + # NOTE: might want to implement this inside handler instead for p in prompt: p["role"] = "user" - + response = llm.chat.completions.create( model=model_name, messages=prompt, n=1, - temperature=1, # has to be 1 - max_completion_tokens=max_tokens + temperature=1, # has to be 1 + max_completion_tokens=max_tokens, ) else: response = llm.chat.completions.create( @@ -37,30 +42,50 @@ def fetch_response_openai(llm, model_name, max_tokens, temp, prompt): messages=prompt, n=1, temperature=temp, - max_tokens=max_tokens + max_tokens=max_tokens, ) return response -def perform_inference_and_check(handler: TaskHandler, temperatures, max_tokens, result_file, llm, system_prompt, args): + +def perform_inference_and_check( + handler: TaskHandler, + temperatures, + max_tokens, + result_file, + llm, + system_prompt, + args, +): results = handler.load_existing_results(result_file) print(f"Loaded {len(results)} existing results.") - train_data = handler.load_and_filter_dataset(args.start, args.end, split=args.split, source=args.source, \ - 
filter_difficulty=args.filter_difficulty, args=args) + train_data = handler.load_and_filter_dataset( + args.start, + args.end, + split=args.split, + source=args.source, + filter_difficulty=args.filter_difficulty, + args=args, + ) remaining_data = handler.process_remaining_data(train_data, results) - conversations = handler.make_conversations(remaining_data, system_prompt, args.model) + conversations = handler.make_conversations( + remaining_data, system_prompt, args.model + ) for temp in temperatures: - if args.model.startswith("openai"): - fetch_partial = partial(fetch_response_openai, llm, args.model, max_tokens, temp) + fetch_partial = partial( + fetch_response_openai, llm, args.model, max_tokens, temp + ) with concurrent.futures.ThreadPoolExecutor(max_workers=16) as e: responses = list(e.map(fetch_partial, conversations)) else: sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temp) - responses = llm.chat(messages=conversations, sampling_params=sampling_params, use_tqdm=True) - - total_correct = 0 + responses = llm.chat( + messages=conversations, sampling_params=sampling_params, use_tqdm=True + ) + + total_correct = 0 total_finish = 0 with ProcessPoolExecutor(max_workers=32) as executor: # future_to_task = { @@ -74,18 +99,26 @@ def perform_inference_and_check(handler: TaskHandler, temperatures, max_tokens, response_str = response.choices[0].message.content.strip() else: response_str = response.outputs[0].text.strip() - future_to_task[executor.submit(handler.update_results, remaining_data[idx], response_str)] = idx + future_to_task[ + executor.submit( + handler.update_results, remaining_data[idx], response_str + ) + ] = idx # print(f"Request output: {response}") - + if args.model.startswith("openai"): token_usages[idx] = response.usage else: token_usages[idx] = { "completion_tokens": len(response.outputs[0].token_ids), - "prompt_tokens": len(response.prompt_token_ids) + "prompt_tokens": len(response.prompt_token_ids), } - for future in tqdm(as_completed(future_to_task), total=len(future_to_task), desc="Processing Generations"): + for future in tqdm( + as_completed(future_to_task), + total=len(future_to_task), + desc="Processing Generations", + ): idx = future_to_task[future] response_entry = future.result() total_correct += response_entry["correctness"] @@ -102,7 +135,7 @@ def perform_inference_and_check(handler: TaskHandler, temperatures, max_tokens, results[problem_key]["prompt"] = prompt results[problem_key]["responses"][str(temp)] = response_entry - + if args.model.startswith("openai"): results[problem_key]["token_usages"][str(temp)] = { "completion_tokens": token_usages[idx].completion_tokens, @@ -110,19 +143,24 @@ def perform_inference_and_check(handler: TaskHandler, temperatures, max_tokens, } else: # TODO: vLLM model, can it do the same thing - results[problem_key]["token_usages"][str(temp)] = token_usages[idx] - + results[problem_key]["token_usages"][str(temp)] = token_usages[idx] + print(f"Final acc: {total_correct}/{total_finish}") acc = round(total_correct / total_finish, 4) if total_finish > 0 else 0 print(json.dumps({"acc": acc})) completion_tokens = [ - results[key].get("token_usages", {}).get(str(temp), {}).get("completion_tokens", 0) - for key in results for temp in temperatures + results[key] + .get("token_usages", {}) + .get(str(temp), {}) + .get("completion_tokens", 0) + for key in results + for temp in temperatures ] prompt_tokens = [ results[key].get("token_usages", {}).get(str(temp), {}).get("prompt_tokens", 0) - for key in results for temp in 
temperatures + for key in results + for temp in temperatures ] # Token usage summary @@ -137,8 +175,14 @@ def perform_inference_and_check(handler: TaskHandler, temperatures, max_tokens, token_dict = { "completion_tokens": sum(completion_tokens), "prompt_tokens": sum(prompt_tokens), - "avg_completion_tokens": round(sum(completion_tokens) / len(completion_tokens), 3) if completion_tokens else 0, - "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) if prompt_tokens else 0, + "avg_completion_tokens": round( + sum(completion_tokens) / len(completion_tokens), 3 + ) + if completion_tokens + else 0, + "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) + if prompt_tokens + else 0, } # Save the token usage dictionary to the result file @@ -146,16 +190,23 @@ def perform_inference_and_check(handler: TaskHandler, temperatures, max_tokens, json.dump(token_dict, f, indent=4) print(f"Token usage saved to {token_usage_result_file}") - - with open(result_file, 'w', encoding='utf-8') as file: + + with open(result_file, "w", encoding="utf-8") as file: json.dump(results, file, ensure_ascii=False, indent=4, cls=NumpyEncoder) + def perform_check(handler: TaskHandler, temperatures, result_file, args): results = handler.load_existing_results(result_file) print(f"Loaded {len(results)} existing results.") - train_data = handler.load_and_filter_dataset(args.start, args.end, split=args.split, source=args.source, \ - filter_difficulty=args.filter_difficulty, args=args) + train_data = handler.load_and_filter_dataset( + args.start, + args.end, + split=args.split, + source=args.source, + filter_difficulty=args.filter_difficulty, + args=args, + ) remaining_data = handler.process_remaining_data(train_data, {}) tasks = [] @@ -167,36 +218,58 @@ def perform_check(handler: TaskHandler, temperatures, result_file, args): if str(temp) in results[problem_key]["responses"]: response_entries = results[problem_key]["responses"][str(temp)] for sample_id, response_entry in enumerate(response_entries): - if sample_id > (args.n - 1): continue + if sample_id > (args.n - 1): + continue if True or response_entry["correctness"] is None: processed = "processed_content" in response_entry - tasks.append((item, temp, response_entry["processed_content"] if processed else response_entry["content"], sample_id)) + tasks.append( + ( + item, + temp, + response_entry["processed_content"] + if processed + else response_entry["content"], + sample_id, + ) + ) print(f"Found {len(tasks)} responses requiring reject sampling...") total_correct = 0 total_finish = 0 - correct = { temp: {} for temp in temperatures } + correct = {temp: {} for temp in temperatures} with ProcessPoolExecutor(max_workers=32) as executor: future_to_task = { - executor.submit(handler.update_results, item, content): (item, temp, sample_id) + executor.submit(handler.update_results, item, content): ( + item, + temp, + sample_id, + ) for (item, temp, content, sample_id) in tasks } # 4. Collect the results as they finish. 
- for future in tqdm(as_completed(future_to_task), total=len(future_to_task), desc="Processing Reject Sampling"): + for future in tqdm( + as_completed(future_to_task), + total=len(future_to_task), + desc="Processing Reject Sampling", + ): item, temp, sample_id = future_to_task[future] new_response_entry = future.result() total_correct += new_response_entry["correctness"] total_finish += 1 - + # Update the corresponding record in results problem_key = item[handler.get_question_key()] if problem_key not in correct[temp]: correct[temp][problem_key] = False if new_response_entry["correctness"]: correct[temp][problem_key] = True - assert problem_key in results and "responses" in results[problem_key] and str(temp) in results[problem_key]["responses"] + assert ( + problem_key in results + and "responses" in results[problem_key] + and str(temp) in results[problem_key]["responses"] + ) response_entry = results[problem_key]["responses"][str(temp)][sample_id] response_entry["correctness"] = new_response_entry["correctness"] response_entry["reason"] = new_response_entry["reason"] @@ -210,28 +283,50 @@ def perform_check(handler: TaskHandler, temperatures, result_file, args): temp_acc = round(temp_correct / temp_total, 4) if temp_total > 0 else 0 print(f"Temperature {temp} acc: {temp_correct}/{temp_total} ({temp_acc})") - with open(result_file, 'w', encoding='utf-8') as file: + with open(result_file, "w", encoding="utf-8") as file: json.dump(results, file, ensure_ascii=False, indent=4, cls=NumpyEncoder) - -def perform_inference_and_save(handler: TaskHandler, temperatures, max_tokens, result_file, llm, system_prompt, args): + +def perform_inference_and_save( + handler: TaskHandler, + temperatures, + max_tokens, + result_file, + llm, + system_prompt, + args, +): results = handler.load_existing_results(result_file) print(f"Loaded {len(results)} existing results.") - train_data = handler.load_and_filter_dataset(args.start, args.end, split=args.split, source=args.source, \ - filter_difficulty=args.filter_difficulty, args=args) + train_data = handler.load_and_filter_dataset( + args.start, + args.end, + split=args.split, + source=args.source, + filter_difficulty=args.filter_difficulty, + args=args, + ) remaining_data = handler.process_remaining_data(train_data, results) - conversations = handler.make_conversations(remaining_data, system_prompt, args.model) - + conversations = handler.make_conversations( + remaining_data, system_prompt, args.model + ) + for temp in temperatures: if args.model.startswith("openai"): - fetch_partial = partial(fetch_response_openai, llm, args.model, max_tokens, temp) + fetch_partial = partial( + fetch_response_openai, llm, args.model, max_tokens, temp + ) with concurrent.futures.ThreadPoolExecutor(max_workers=16) as e: responses = list(e.map(fetch_partial, conversations)) else: - sampling_params = SamplingParams(n=args.n, max_tokens=max_tokens, temperature=temp) - responses = llm.chat(messages=conversations, sampling_params=sampling_params, use_tqdm=True) + sampling_params = SamplingParams( + n=args.n, max_tokens=max_tokens, temperature=temp + ) + responses = llm.chat( + messages=conversations, sampling_params=sampling_params, use_tqdm=True + ) completion_tokens = [] prompt_tokens = [] @@ -241,23 +336,31 @@ def perform_inference_and_save(handler: TaskHandler, temperatures, max_tokens, r completion_token = 0 for sample_idx in range(args.n): response_entry = { - "content": response.choices[0].message.content.strip() if args.model.startswith("openai") else 
response.outputs[sample_idx].text.strip(), + "content": response.choices[0].message.content.strip() + if args.model.startswith("openai") + else response.outputs[sample_idx].text.strip(), "correctness": None, "reason": None, } response_entries.append(response_entry) if not args.model.startswith("openai"): - token_usages.append({ - "completion_tokens": len(response.outputs[sample_idx].token_ids), - "prompt_tokens": len(response.prompt_token_ids) - }) + token_usages.append( + { + "completion_tokens": len( + response.outputs[sample_idx].token_ids + ), + "prompt_tokens": len(response.prompt_token_ids), + } + ) completion_token += len(response.outputs[sample_idx].token_ids) completion_token /= args.n prompt_token = len(response.prompt_token_ids) prompt_tokens.append(prompt_token) completion_tokens.append(completion_token) - problem_key = remaining_data[idx][handler.get_question_key()] # can you use this idx + problem_key = remaining_data[idx][ + handler.get_question_key() + ] # can you use this idx if problem_key not in results: results[problem_key] = remaining_data[idx] if isinstance(handler, NUMINATaskHandler): @@ -268,7 +371,7 @@ def perform_inference_and_save(handler: TaskHandler, temperatures, max_tokens, r results[problem_key]["prompt"] = prompt results[problem_key]["responses"][str(temp)] = response_entries - + if args.model.startswith("openai"): results[problem_key]["token_usages"][str(temp)] = { "completion_tokens": response.usage.completion_tokens, @@ -289,8 +392,14 @@ def perform_inference_and_save(handler: TaskHandler, temperatures, max_tokens, r token_dict = { "completion_tokens": sum(completion_tokens), "prompt_tokens": sum(prompt_tokens), - "avg_completion_tokens": round(sum(completion_tokens) / len(completion_tokens), 3) if completion_tokens else 0, - "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) if prompt_tokens else 0, + "avg_completion_tokens": round( + sum(completion_tokens) / len(completion_tokens), 3 + ) + if completion_tokens + else 0, + "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) + if prompt_tokens + else 0, } # Save the token usage dictionary to the result file @@ -298,33 +407,94 @@ def perform_inference_and_save(handler: TaskHandler, temperatures, max_tokens, r json.dump(token_dict, f, indent=4) print(f"Token usage saved to {token_usage_result_file}") - - with open(result_file, 'w', encoding='utf-8') as file: + + with open(result_file, "w", encoding="utf-8") as file: json.dump(results, file, ensure_ascii=False, indent=4, cls=NumpyEncoder) + def main(): - parser = argparse.ArgumentParser(description="Unified inference and checking for different datasets/tasks.") - parser.add_argument("--dataset", type=str, required=True, choices=["NUMINA", "APPS", "TACO", "MATH500", "AIME", "GPQADiamond", "MMLU", "MMLUPro", "LiveCodeBench", "GSM8K", "ARC-C", "AMC23"], help="Dataset to process.") - parser.add_argument("--model", type=str, required=True, default="Qwen/QwQ-32B-Preview", help="The model to run.") + parser = argparse.ArgumentParser( + description="Unified inference and checking for different datasets/tasks." 
+ ) + parser.add_argument( + "--dataset", + type=str, + required=True, + choices=[ + "NUMINA", + "APPS", + "TACO", + "MATH500", + "AIME", + "GPQADiamond", + "MMLU", + "MMLUPro", + "LiveCodeBench", + "GSM8K", + "ARC-C", + "AMC23", + ], + help="Dataset to process.", + ) + parser.add_argument( + "--model", + type=str, + required=True, + default="Qwen/QwQ-32B-Preview", + help="The model to run.", + ) parser.add_argument("--tp", type=int, default=8, help="Tensor Parallelism Degree") - parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for the model.") - parser.add_argument("--split", type=str, default="train", help="Split to use for apps (e.g., train, test).") + parser.add_argument( + "--max_tokens", type=int, default=32768, help="Max tokens for the model." + ) + parser.add_argument( + "--split", + type=str, + default="train", + help="Split to use for apps (e.g., train, test).", + ) parser.add_argument("--source", type=str, help="Source for the dataset.") parser.add_argument("--start", type=int, default=0, help="Start index.") parser.add_argument("--end", type=int, default=-1, help="End index.") - parser.add_argument("--filter-difficulty", action="store_true", help="Filter difficulty.") - parser.add_argument("--result-dir", type=str, default="./", help="Result dir to save files.") - parser.add_argument("--check", action="store_true", help="Perform evaluation checks on generated samples.") + parser.add_argument( + "--filter-difficulty", action="store_true", help="Filter difficulty." + ) + parser.add_argument( + "--result-dir", type=str, default="./", help="Result dir to save files." + ) + parser.add_argument( + "--check", + action="store_true", + help="Perform evaluation checks on generated samples.", + ) parser.add_argument("--inference", action="store_true", help="Perform inference.") - parser.add_argument("--temperatures", type=float, nargs="+", default=[0], help="Temperature for sampling.") - parser.add_argument("--math-difficulty-lower-bound", type=int, default=None, help="Lowest difficulty level for math.") - parser.add_argument("--math-difficulty-upper-bound", type=int, default=None, help="Highest difficulty level for math.") - parser.add_argument("--n", type=int, default=1, help="Number of samples generated per problem.") + parser.add_argument( + "--temperatures", + type=float, + nargs="+", + default=[0], + help="Temperature for sampling.", + ) + parser.add_argument( + "--math-difficulty-lower-bound", + type=int, + default=None, + help="Lowest difficulty level for math.", + ) + parser.add_argument( + "--math-difficulty-upper-bound", + type=int, + default=None, + help="Highest difficulty level for math.", + ) + parser.add_argument( + "--n", type=int, default=1, help="Number of samples generated per problem." 
+ ) args = parser.parse_args() - + handler: TaskHandler = TASK_HANDLERS[args.dataset]() - temperatures = [1] if args.model.startswith("openai/o1") else args.temperatures - + temperatures = [1] if args.model.startswith("openai/o1") else args.temperatures + print(f"Temperature: {temperatures}") max_tokens = args.max_tokens if temperatures == [0] and args.n > 1: @@ -334,14 +504,26 @@ def main(): # create result dir if not exists if args.result_dir and not os.path.exists(args.result_dir): os.makedirs(args.result_dir) - if args.math_difficulty_lower_bound is not None or args.math_difficulty_upper_bound is not None: - result_file = os.path.join(args.result_dir, f"{MODEL_TO_NAME[args.model]}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json") + if ( + args.math_difficulty_lower_bound is not None + or args.math_difficulty_upper_bound is not None + ): + result_file = os.path.join( + args.result_dir, + f"{MODEL_TO_NAME[args.model]}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json", + ) else: - result_file = os.path.join(args.result_dir, f"{MODEL_TO_NAME[args.model]}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}.json") + result_file = os.path.join( + args.result_dir, + f"{MODEL_TO_NAME[args.model]}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}.json", + ) if args.check: # check if converted file exists - if args.math_difficulty_lower_bound is not None or args.math_difficulty_upper_bound is not None: + if ( + args.math_difficulty_lower_bound is not None + or args.math_difficulty_upper_bound is not None + ): converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json" else: converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}.json" @@ -350,14 +532,27 @@ def main(): perform_check(handler, temperatures, result_file, args) return elif args.inference: - llm = OpenAI() if args.model.startswith("openai") else LLM(model=args.model, tensor_parallel_size=args.tp) + llm = ( + OpenAI() + if args.model.startswith("openai") + else LLM(model=args.model, tensor_parallel_size=args.tp) + ) system_prompt = SYSTEM_PROMPT[args.model] - perform_inference_and_save(handler, temperatures, max_tokens, result_file, llm, system_prompt, args) + perform_inference_and_save( + handler, temperatures, max_tokens, result_file, llm, system_prompt, args + ) return - llm = OpenAI() if args.model.startswith("openai") else LLM(model=args.model, tensor_parallel_size=args.tp) + llm = ( + OpenAI() + if args.model.startswith("openai") + else LLM(model=args.model, tensor_parallel_size=args.tp) + ) system_prompt = SYSTEM_PROMPT[args.model] - perform_inference_and_check(handler, temperatures, max_tokens, result_file, llm, system_prompt, args) + perform_inference_and_check( + handler, temperatures, max_tokens, result_file, llm, system_prompt, args + ) + if __name__ == "__main__": main() diff --git a/skythought/tools/label_math_difficulty.py b/skythought/tools/label_math_difficulty.py index 57a0ffd..3ab5804 100644 --- a/skythought/tools/label_math_difficulty.py +++ b/skythought/tools/label_math_difficulty.py @@ -1,20 +1,24 @@ -import json import argparse -from tqdm import tqdm +import ast 
+import json import multiprocessing as mp -import openai -from itertools import cycle -import time import os -from datasets import load_dataset import re -import ast -from util.prompts import grading_prompt, aops_criteria +import time +from itertools import cycle + +import openai +from datasets import load_dataset +from tqdm import tqdm + +from util.prompts import aops_criteria, grading_prompt + # Function to set the OpenAI API key def set_openai_key(api_key): openai.api_key = api_key + # From FastChat def find_difficulty(judgment): one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]") @@ -27,14 +31,15 @@ def find_difficulty(judgment): rating = ast.literal_eval(match.groups()[0]) else: rating = -1 - + return rating + # GPT API processing function with retry logic def process_content(problem, api_key): # Set the OpenAI key for this request set_openai_key(api_key) - + # GPT prompt prompt = grading_prompt.format(problem=problem, aops_criteria=aops_criteria) retries = 3 @@ -44,22 +49,26 @@ def process_content(problem, api_key): response = openai.chat.completions.create( model="gpt-4o-mini", messages=[ - {"role": "system", "content": "You are a math problem difficulty labeler."}, - {"role": "user", "content": prompt} + { + "role": "system", + "content": "You are a math problem difficulty labeler.", + }, + {"role": "user", "content": prompt}, ], max_tokens=2048, - temperature=0.7 + temperature=0.7, ) return response.choices[0].message.content except openai.RateLimitError: retries -= 1 if retries == 0: return "Error: Rate limit reached and retries exhausted." - print(f"Sleep for 5 seconds for API limit.") + print("Sleep for 5 seconds for API limit.") time.sleep(5) except Exception as e: return f"Error processing content: {e}" + def process_entry(entry, api_key_cycle): # Get the next API key from the cycle api_key = next(api_key_cycle) @@ -73,24 +82,28 @@ def process_entry(entry, api_key_cycle): entry["gpt_difficulty_parsed"] = find_difficulty(processed) return entry + # Wrapper function for multiprocessing def process_entry_wrapper(args): return process_entry(*args) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Label difficulty") parser.add_argument("--source", type=str, help="") parser.add_argument("--start", type=int, default=0, help="") parser.add_argument("--end", type=int, default=-1, help="") - parser.add_argument("--keys", type=str, help="File containing OpenAI API keys (one per line).") + parser.add_argument( + "--keys", type=str, help="File containing OpenAI API keys (one per line)." 
+ ) args = parser.parse_args() dataset = load_dataset("AI-MO/NuminaMath-CoT") data = ( dataset["train"] .to_pandas() - .query('source == @args.source') - .iloc[args.start:args.end] + .query("source == @args.source") + .iloc[args.start : args.end] ) data = data.to_dict(orient="records") diff --git a/skythought/tools/pyproject.toml b/skythought/tools/pyproject.toml new file mode 100644 index 0000000..88570fd --- /dev/null +++ b/skythought/tools/pyproject.toml @@ -0,0 +1,3 @@ +[tool.ruff] +lint.select = ["E", "F", "I", "ASYNC", "B"] +line-length = 300 \ No newline at end of file diff --git a/skythought/tools/response_rewrite.py b/skythought/tools/response_rewrite.py index 6dff3a6..eaa307a 100644 --- a/skythought/tools/response_rewrite.py +++ b/skythought/tools/response_rewrite.py @@ -2,387 +2,489 @@ import json import os import random -from tqdm import tqdm + from skythought.tools.util.math_parsing_util import strip_answer_string -from util.model_utils import * +from tqdm import tqdm from vllm import LLM, SamplingParams -def load_dataset(dataset_path : str): - data = {} - with open(dataset_path, 'r', encoding='utf-8') as file: - data = json.load(file) - return data +from util.model_utils import ( + SUBPROBLEM_SPLIT_PROMPT, + SUBSOLUTION_EXTRACTION_PROMPT, + SYSTEM_PROMPT, +) -def make_scoring_conversations(dataset, system_prompt): - conversations = [] - for _, key in enumerate(dataset): - problem = dataset[key] - gt_answer = strip_answer_string(problem["answer"]) - for response_key in problem["responses"]: - response = problem["responses"][response_key]["content"] - prompt_text = response + "\n#####\nThe ground truth answer is " + gt_answer - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) +def load_dataset(dataset_path: str): + data = {} + with open(dataset_path, "r", encoding="utf-8") as file: + data = json.load(file) + return data - return conversations + +def make_scoring_conversations(dataset, system_prompt): + conversations = [] + for _, key in enumerate(dataset): + problem = dataset[key] + gt_answer = strip_answer_string(problem["answer"]) + for response_key in problem["responses"]: + response = problem["responses"][response_key]["content"] + prompt_text = response + "\n#####\nThe ground truth answer is " + gt_answer + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) + + return conversations def score_solutions(dataset, responses, outfile): - idx = 0 - for _, key in tqdm(enumerate(dataset), total=len(dataset), desc="Scoring original solutions"): - problem = dataset[key] - for response_key in problem["responses"]: - score = responses[idx].outputs[0].text.strip() - problem["responses"][response_key]["correctness"] = (score == "True") - idx += 1 + idx = 0 + for _, key in tqdm( + enumerate(dataset), total=len(dataset), desc="Scoring original solutions" + ): + problem = dataset[key] + for response_key in problem["responses"]: + score = responses[idx].outputs[0].text.strip() + problem["responses"][response_key]["correctness"] = score == "True" + idx += 1 - with open(outfile, 'w', encoding='utf-8') as new_file: - json.dump(dataset, new_file, ensure_ascii=False, indent=2) - return dataset + with open(outfile, "w", encoding="utf-8") as new_file: + json.dump(dataset, new_file, ensure_ascii=False, indent=2) + return dataset def filter_solutions(dataset): - # First filter out incorrect responses. 
- for key in dataset: - problem = dataset[key] + # First filter out incorrect responses. + for key in dataset: + problem = dataset[key] + keys_to_filter = [] + for response_key in problem["responses"]: + if not problem["responses"][response_key]["correctness"]: + keys_to_filter.append(response_key) + for k in keys_to_filter: + del problem["responses"][k] + del problem["token_usages"][k] + + # Next, filter out examples with <2 correct responses. keys_to_filter = [] - for response_key in problem["responses"]: - if not problem["responses"][response_key]["correctness"]: - keys_to_filter.append(response_key) + for key in dataset: + problem = dataset[key] + if len(problem["responses"]) < 2: + keys_to_filter.append(key) for k in keys_to_filter: - del problem["responses"][k] - del problem["token_usages"][k] - - # Next, filter out examples with <2 correct responses. - keys_to_filter = [] - for key in dataset: - problem = dataset[key] - if len(problem["responses"]) < 2: - keys_to_filter.append(key) - for k in keys_to_filter: - del dataset[k] - - # Finally, filter for the shortest and longest solutions for each sample. - for key in dataset: - problem = dataset[key] - token_usages = problem["token_usages"] - shortest_key, shortest_entry = min(token_usages.items(), key=lambda x: x[1]["completion_tokens"]) - longest_key, longest_entry = max(token_usages.items(), key=lambda x: x[1]["completion_tokens"]) - problem["token_usages"] = { - "shortest": shortest_entry, - "longest": longest_entry, - } - new_responses = { - "shortest": problem["responses"][shortest_key], - "longest":problem["responses"][longest_key], - } - problem["responses"] = new_responses - - return dataset - + del dataset[k] + + # Finally, filter for the shortest and longest solutions for each sample. + for key in dataset: + problem = dataset[key] + token_usages = problem["token_usages"] + shortest_key, shortest_entry = min( + token_usages.items(), key=lambda x: x[1]["completion_tokens"] + ) + longest_key, longest_entry = max( + token_usages.items(), key=lambda x: x[1]["completion_tokens"] + ) + problem["token_usages"] = { + "shortest": shortest_entry, + "longest": longest_entry, + } + new_responses = { + "shortest": problem["responses"][shortest_key], + "longest": problem["responses"][longest_key], + } + problem["responses"] = new_responses + + return dataset + def make_splitting_conversations(data, system_prompt): - conversations = [] - for problem in data: - response = data[problem]["responses"]["shortest"] - prompt_text = response["content"] - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) - return conversations + conversations = [] + for problem in data: + response = data[problem]["responses"]["shortest"] + prompt_text = response["content"] + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) + return conversations def split_solutions(dataset, responses, delimiter): - outputs = [] - for _, response in tqdm(enumerate(responses), total=len(responses), desc="Splitting responses"): - content = response.outputs[0].text.strip() - # Split response by configured delimiter. 
- split_content = content.split(delimiter) - split_content = [x.strip() for x in split_content if x != ""] - outputs.append(split_content) - for idx, key in enumerate(dataset): - solutions = outputs[idx] - problem = dataset[key] - problem["responses"]["shortest"]["subsolutions"] = solutions - return dataset + outputs = [] + for _, response in tqdm( + enumerate(responses), total=len(responses), desc="Splitting responses" + ): + content = response.outputs[0].text.strip() + # Split response by configured delimiter. + split_content = content.split(delimiter) + split_content = [x.strip() for x in split_content if x != ""] + outputs.append(split_content) + for idx, key in enumerate(dataset): + solutions = outputs[idx] + problem = dataset[key] + problem["responses"]["shortest"]["subsolutions"] = solutions + return dataset def make_subscoring_conversations(dataset, system_prompt): - conversations = [] - for _, key in enumerate(dataset): - problem = dataset[key] - gt_answer = strip_answer_string(problem["answer"]) - subsolutions = problem["responses"]["shortest"]["subsolutions"] - for sub in subsolutions: - prompt_text = sub + "\n#####\nThe ground truth answer is " + gt_answer - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) - return conversations + conversations = [] + for _, key in enumerate(dataset): + problem = dataset[key] + gt_answer = strip_answer_string(problem["answer"]) + subsolutions = problem["responses"]["shortest"]["subsolutions"] + for sub in subsolutions: + prompt_text = sub + "\n#####\nThe ground truth answer is " + gt_answer + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) + return conversations def score_subsolutions(dataset, responses): - idx = 0 - for _, key in tqdm(enumerate(dataset), total=len(dataset), desc="Scoring sub-solutions"): - problem = dataset[key] - subsolutions = problem["responses"]["shortest"]["subsolutions"] - scores = [] - for _, sub in enumerate(subsolutions): - score = responses[idx].outputs[0].text.strip() - scores.append(score == "True") - idx += 1 - problem["responses"]["shortest"]["scores"] = scores - return dataset + idx = 0 + for _, key in tqdm( + enumerate(dataset), total=len(dataset), desc="Scoring sub-solutions" + ): + problem = dataset[key] + subsolutions = problem["responses"]["shortest"]["subsolutions"] + scores = [] + for _, _ in enumerate(subsolutions): + score = responses[idx].outputs[0].text.strip() + scores.append(score == "True") + idx += 1 + problem["responses"]["shortest"]["scores"] = scores + return dataset def build_response_variants(dataset): - def clean_response_string(response): - if '<|end_of_thought|>' not in response: - response += '<|end_of_thought|>' - return response - - keys_to_remove = [] - - for key, problem in dataset.items(): - scores = problem["responses"]["shortest"]["scores"] - subsolutions = problem["responses"]["shortest"]["subsolutions"] - - # Check if there are valid scores - if True not in scores: - keys_to_remove.append(key) - continue - - # Build FCS (First Correct Solution) - fcs_idx = scores.index(True) - fcs_response = "\n".join(subsolutions[:fcs_idx + 1]) if fcs_idx < len(scores) - 1 else "\n".join(subsolutions[:-1]) - fcs_response = clean_response_string(fcs_response) + "\n" + subsolutions[-1] - problem["responses"]["fcs"] = fcs_response - - # Build FCS + 1 - fcs_plus1_idx = fcs_idx + 1 if fcs_idx + 1 < len(subsolutions) - 1 else fcs_idx - fcs_plus1_response = 
"\n".join(subsolutions[:fcs_plus1_idx + 1]) - fcs_plus1_response = clean_response_string(fcs_plus1_response) + "\n" + subsolutions[-1] - problem["responses"]["fcs_plus1"] = fcs_plus1_response - - # Check if there are valid scores - if True not in scores[fcs_idx + 1:]: - keys_to_remove.append(key) - continue - - # Build FCS + Reflection - fcs_reflection_idx = scores.index(True, fcs_idx + 1) - fcs_reflection_response = "\n".join(subsolutions[:fcs_reflection_idx + 1]) if fcs_reflection_idx < len(scores) - 1 else "\n".join(subsolutions[:-1]) - fcs_reflection_response = clean_response_string(fcs_reflection_response) + "\n" + subsolutions[-1] - problem["responses"]["fcs_reflection"] = fcs_reflection_response - - # Remove problems without valid sub-solutions - for key in keys_to_remove: - del dataset[key] - - return dataset + def clean_response_string(response): + if "<|end_of_thought|>" not in response: + response += "<|end_of_thought|>" + return response + + keys_to_remove = [] + + for key, problem in dataset.items(): + scores = problem["responses"]["shortest"]["scores"] + subsolutions = problem["responses"]["shortest"]["subsolutions"] + + # Check if there are valid scores + if True not in scores: + keys_to_remove.append(key) + continue + + # Build FCS (First Correct Solution) + fcs_idx = scores.index(True) + fcs_response = ( + "\n".join(subsolutions[: fcs_idx + 1]) + if fcs_idx < len(scores) - 1 + else "\n".join(subsolutions[:-1]) + ) + fcs_response = clean_response_string(fcs_response) + "\n" + subsolutions[-1] + problem["responses"]["fcs"] = fcs_response + + # Build FCS + 1 + fcs_plus1_idx = fcs_idx + 1 if fcs_idx + 1 < len(subsolutions) - 1 else fcs_idx + fcs_plus1_response = "\n".join(subsolutions[: fcs_plus1_idx + 1]) + fcs_plus1_response = ( + clean_response_string(fcs_plus1_response) + "\n" + subsolutions[-1] + ) + problem["responses"]["fcs_plus1"] = fcs_plus1_response + + # Check if there are valid scores + if True not in scores[fcs_idx + 1 :]: + keys_to_remove.append(key) + continue + + # Build FCS + Reflection + fcs_reflection_idx = scores.index(True, fcs_idx + 1) + fcs_reflection_response = ( + "\n".join(subsolutions[: fcs_reflection_idx + 1]) + if fcs_reflection_idx < len(scores) - 1 + else "\n".join(subsolutions[:-1]) + ) + fcs_reflection_response = ( + clean_response_string(fcs_reflection_response) + "\n" + subsolutions[-1] + ) + problem["responses"]["fcs_reflection"] = fcs_reflection_response + + # Remove problems without valid sub-solutions + for key in keys_to_remove: + del dataset[key] + + return dataset def compute_token_usages(dataset, variants, llm): - tokenizer = llm.get_tokenizer() - for key in tqdm(dataset, desc="Computing token usages", total=len(dataset)): - problem = dataset[key] - prompt_tokens = problem["token_usages"]["shortest"]["prompt_tokens"] - for variant in variants: - problem["token_usages"][variant] = { - "prompt_tokens": prompt_tokens, - "completion_tokens": len(tokenizer(problem["responses"][variant]).input_ids) - } - return dataset + tokenizer = llm.get_tokenizer() + for key in tqdm(dataset, desc="Computing token usages", total=len(dataset)): + problem = dataset[key] + prompt_tokens = problem["token_usages"]["shortest"]["prompt_tokens"] + for variant in variants: + problem["token_usages"][variant] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": len( + tokenizer(problem["responses"][variant]).input_ids + ), + } + return dataset def build_question_prompt(prompt): - return "Return your final response within \\boxed{{}}" + prompt + return 
"Return your final response within \\boxed{{}}" + prompt + def make_preference_conversations(final_dataset, format, system_prompt): - conversations = [] - for prompt in final_dataset: - problem = final_dataset[prompt] - convo = {} - convo["conversations"] = [ - { - "from": "system", - "value": system_prompt, - }, - { - "from": "human", - "value": build_question_prompt(prompt), - } - ] - convo["chosen"] = { - "from": "gpt", - "value": problem["responses"][format], - } - convo["rejected"] = { - "from": "gpt", - "value": problem["responses"]["longest"]["content"] - } - conversations.append(convo) - - return conversations + conversations = [] + for prompt in final_dataset: + problem = final_dataset[prompt] + convo = {} + convo["conversations"] = [ + { + "from": "system", + "value": system_prompt, + }, + { + "from": "human", + "value": build_question_prompt(prompt), + }, + ] + convo["chosen"] = { + "from": "gpt", + "value": problem["responses"][format], + } + convo["rejected"] = { + "from": "gpt", + "value": problem["responses"]["longest"]["content"], + } + conversations.append(convo) + + return conversations def make_SILC_conversations(dataset, system_prompt): - keys_to_filter = [] - for prompt in dataset: - problem = dataset[prompt] - contition = False - for response_key in problem["responses"]: - if not problem["responses"][response_key]['correctness']: - wrong_length = problem["token_usages"][response_key]['completion_tokens'] - for k in problem["responses"]: - if k != response_key and problem["token_usages"][k]['completion_tokens'] > wrong_length and problem["responses"][k]['correctness']: - contition = True - break - break - if not contition: - keys_to_filter.append(prompt) - - for key in keys_to_filter: - del dataset[key] - - # Build contrastive pairs out of {short incorrect, long correct} - conversations = [] - for prompt in dataset: - problem = dataset[prompt] - - shortest_incorrect_key = None - shortest_incorrect_length = float('inf') - - # Get shortest incorrect. - for response_key in problem["responses"]: - if not problem["responses"][response_key]['correctness']: - length = problem["token_usages"][response_key]['completion_tokens'] - if length < shortest_incorrect_length: - shortest_incorrect_length = length - shortest_incorrect_key = response_key - - # Get next longest correct. 
- shortest_correct_longer_key = None - shortest_correct_longer_length = float('inf') - for response_key in problem["responses"]: - if problem["responses"][response_key]['correctness']: - length = problem["token_usages"][response_key]['completion_tokens'] - if length > shortest_incorrect_length and length < shortest_correct_longer_length: - shortest_correct_longer_length = length - shortest_correct_longer_key = response_key - - convo = {} - convo["conversations"] = [ - { - "from": "system", - "value": system_prompt, - }, - { - "from": "human", - "value": build_question_prompt(prompt), - } - ] - convo["chosen"] = { - "from": "gpt", - "value": problem["responses"][shortest_correct_longer_key]['content'], - } - convo["rejected"] = { - "from": "gpt", - "value": problem["responses"][shortest_incorrect_key]["content"] - } - conversations.append(convo) - - return conversations + keys_to_filter = [] + for prompt in dataset: + problem = dataset[prompt] + contition = False + for response_key in problem["responses"]: + if not problem["responses"][response_key]["correctness"]: + wrong_length = problem["token_usages"][response_key][ + "completion_tokens" + ] + for k in problem["responses"]: + if ( + k != response_key + and problem["token_usages"][k]["completion_tokens"] + > wrong_length + and problem["responses"][k]["correctness"] + ): + contition = True + break + break + if not contition: + keys_to_filter.append(prompt) + + for key in keys_to_filter: + del dataset[key] + + # Build contrastive pairs out of {short incorrect, long correct} + conversations = [] + for prompt in dataset: + problem = dataset[prompt] + + shortest_incorrect_key = None + shortest_incorrect_length = float("inf") + + # Get shortest incorrect. + for response_key in problem["responses"]: + if not problem["responses"][response_key]["correctness"]: + length = problem["token_usages"][response_key]["completion_tokens"] + if length < shortest_incorrect_length: + shortest_incorrect_length = length + shortest_incorrect_key = response_key + + # Get next longest correct. 
+ shortest_correct_longer_key = None + shortest_correct_longer_length = float("inf") + for response_key in problem["responses"]: + if problem["responses"][response_key]["correctness"]: + length = problem["token_usages"][response_key]["completion_tokens"] + if ( + length > shortest_incorrect_length + and length < shortest_correct_longer_length + ): + shortest_correct_longer_length = length + shortest_correct_longer_key = response_key + + convo = {} + convo["conversations"] = [ + { + "from": "system", + "value": system_prompt, + }, + { + "from": "human", + "value": build_question_prompt(prompt), + }, + ] + convo["chosen"] = { + "from": "gpt", + "value": problem["responses"][shortest_correct_longer_key]["content"], + } + convo["rejected"] = { + "from": "gpt", + "value": problem["responses"][shortest_incorrect_key]["content"], + } + conversations.append(convo) + + return conversations def main(): - parser = argparse.ArgumentParser(description="Filter, rewrite, and format generated responses for high-quality data curation.") - parser.add_argument("--rewrite-model", type=str, required=True, default="meta-llama/Llama-3.3-70B-Instruct", help="The model used for response processing.") - parser.add_argument("--target-model", type=str, required=True, default="NovaSky-AI/Sky-T1-32B-Preview", help="The target model the rewritten responses will be used to train.") - parser.add_argument("--dataset", type=str, required=True, help="Path to the starting dataset of generated responses to filter from.") - parser.add_argument("--result-dir", type=str, default="./", help="Result directory to save processed data.") - parser.add_argument("--checkpoint", action="store_true", help="Whether to checkpoint the dataset at each step.") - parser.add_argument("--SILC", action="store_true", help="Whether to include short-incorrect/long-correct (SILC) preference pairs.") - parser.add_argument("--tp", type=int, default=8, help="Tensor Parallelism Degree") - parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for the model.") - args = parser.parse_args() - - if args.result_dir and not os.path.exists(args.result_dir): + parser = argparse.ArgumentParser( + description="Filter, rewrite, and format generated responses for high-quality data curation." + ) + parser.add_argument( + "--rewrite-model", + type=str, + required=True, + default="meta-llama/Llama-3.3-70B-Instruct", + help="The model used for response processing.", + ) + parser.add_argument( + "--target-model", + type=str, + required=True, + default="NovaSky-AI/Sky-T1-32B-Preview", + help="The target model the rewritten responses will be used to train.", + ) + parser.add_argument( + "--dataset", + type=str, + required=True, + help="Path to the starting dataset of generated responses to filter from.", + ) + parser.add_argument( + "--result-dir", + type=str, + default="./", + help="Result directory to save processed data.", + ) + parser.add_argument( + "--checkpoint", + action="store_true", + help="Whether to checkpoint the dataset at each step.", + ) + parser.add_argument( + "--SILC", + action="store_true", + help="Whether to include short-incorrect/long-correct (SILC) preference pairs.", + ) + parser.add_argument("--tp", type=int, default=8, help="Tensor Parallelism Degree") + parser.add_argument( + "--max_tokens", type=int, default=32768, help="Max tokens for the model." + ) + args = parser.parse_args() + + if args.result_dir and not os.path.exists(args.result_dir): os.makedirs(args.result_dir) - # Initialize model for data processing. 
- llm = LLM(model=args.rewrite_model, tensor_parallel_size=args.tp) - sampling_params = SamplingParams(max_tokens=args.max_tokens) - - original_dataset = load_dataset(args.dataset) - - # Filter for the shortest and longest correct solutions. - filtered_dataset = filter_solutions(original_dataset) - if args.checkpoint: - outfile = os.path.join(args.result_dir, f"filtered-responses.json") - with open(outfile, 'w', encoding='utf-8') as new_file: - json.dump(filtered_dataset, new_file, ensure_ascii=False, indent=2) - - # Split the shortest solution into subsolutions using the configured model. - conversations = make_splitting_conversations(filtered_dataset, SUBPROBLEM_SPLIT_PROMPT) - responses = llm.chat(messages=conversations, sampling_params=sampling_params, use_tqdm=True) - split_dataset = split_solutions(filtered_dataset, responses, '#####') - if args.checkpoint: - outfile = os.path.join(args.result_dir, f"split-solutions.json") - with open(outfile, 'w', encoding='utf-8') as new_file: - json.dump(split_dataset, new_file, ensure_ascii=False, indent=2) - - # Score the subsolutions using the configured model. - subscoring_conversations = make_subscoring_conversations(split_dataset, SUBSOLUTION_EXTRACTION_PROMPT) - responses = llm.chat(messages=subscoring_conversations, sampling_params=sampling_params, use_tqdm=True) - scored_dataset = score_subsolutions(split_dataset, responses) - if args.checkpoint: - outfile = os.path.join(args.result_dir, f"scored-subsolutions.json") - with open(outfile, 'w', encoding='utf-8') as new_file: - json.dump(scored_dataset, new_file, ensure_ascii=False, indent=2) - - # Rewrite response based on variants of combining sub-solutions. Here are examples for - # FCS, FCS+1, and FCS+Reflection. - variants_dataset = build_response_variants(scored_dataset) - if args.checkpoint: - outfile = os.path.join(args.result_dir, f"response-variants.json") - with open(outfile, 'w', encoding='utf-8') as new_file: - json.dump(variants_dataset, new_file, ensure_ascii=False, indent=2) - - # Add per-variant token counts to dataset for convenience. - final_dataset = compute_token_usages(variants_dataset, ["fcs", "fcs_plus1", "fcs_reflection"], llm) - - system_prompt = SYSTEM_PROMPT[args.target_model] - - # Generate conversation format for each variant, which can be used in SimPO/DPO/etc. - fcs_convo = make_preference_conversations(final_dataset, "fcs", system_prompt) - fcs_plus1_convo = make_preference_conversations(final_dataset, "fcs_plus1", system_prompt) - fcs_reflection_convo = make_preference_conversations(final_dataset, "fcs_reflection", system_prompt) - - # Optionall add short incorrect, long correct (SILC) conversations - if args.SILC: - short_incorrect_long_correct_conversations = make_SILC_conversations(load_dataset(args.dataset), system_prompt) - for convo in [fcs_convo, fcs_plus1_convo, fcs_reflection_convo]: - convo += short_incorrect_long_correct_conversations - random.shuffle(convo) - - # Save final conversation variants. 
- fcs_outfile = os.path.join(args.result_dir, "fcs-conversations.json") - with open(fcs_outfile, 'w', encoding='utf-8') as new_file: - json.dump(fcs_convo, new_file, ensure_ascii=False, indent=2) - - fcs_plus1_outfile = os.path.join(args.result_dir, "fcs_plus1-conversations.json") - with open(fcs_plus1_outfile, 'w', encoding='utf-8') as new_file: - json.dump(fcs_plus1_convo, new_file, ensure_ascii=False, indent=2) - - fcs_reflection_outfile = os.path.join(args.result_dir, "fcs_reflection-conversations.json") - with open(fcs_reflection_outfile, 'w', encoding='utf-8') as new_file: - json.dump(fcs_reflection_convo, new_file, ensure_ascii=False, indent=2) + # Initialize model for data processing. + llm = LLM(model=args.rewrite_model, tensor_parallel_size=args.tp) + sampling_params = SamplingParams(max_tokens=args.max_tokens) + + original_dataset = load_dataset(args.dataset) + + # Filter for the shortest and longest correct solutions. + filtered_dataset = filter_solutions(original_dataset) + if args.checkpoint: + outfile = os.path.join(args.result_dir, "filtered-responses.json") + with open(outfile, "w", encoding="utf-8") as new_file: + json.dump(filtered_dataset, new_file, ensure_ascii=False, indent=2) + + # Split the shortest solution into subsolutions using the configured model. + conversations = make_splitting_conversations( + filtered_dataset, SUBPROBLEM_SPLIT_PROMPT + ) + responses = llm.chat( + messages=conversations, sampling_params=sampling_params, use_tqdm=True + ) + split_dataset = split_solutions(filtered_dataset, responses, "#####") + if args.checkpoint: + outfile = os.path.join(args.result_dir, "split-solutions.json") + with open(outfile, "w", encoding="utf-8") as new_file: + json.dump(split_dataset, new_file, ensure_ascii=False, indent=2) + + # Score the subsolutions using the configured model. + subscoring_conversations = make_subscoring_conversations( + split_dataset, SUBSOLUTION_EXTRACTION_PROMPT + ) + responses = llm.chat( + messages=subscoring_conversations, + sampling_params=sampling_params, + use_tqdm=True, + ) + scored_dataset = score_subsolutions(split_dataset, responses) + if args.checkpoint: + outfile = os.path.join(args.result_dir, "scored-subsolutions.json") + with open(outfile, "w", encoding="utf-8") as new_file: + json.dump(scored_dataset, new_file, ensure_ascii=False, indent=2) + + # Rewrite response based on variants of combining sub-solutions. Here are examples for + # FCS, FCS+1, and FCS+Reflection. + variants_dataset = build_response_variants(scored_dataset) + if args.checkpoint: + outfile = os.path.join(args.result_dir, "response-variants.json") + with open(outfile, "w", encoding="utf-8") as new_file: + json.dump(variants_dataset, new_file, ensure_ascii=False, indent=2) + + # Add per-variant token counts to dataset for convenience. + final_dataset = compute_token_usages( + variants_dataset, ["fcs", "fcs_plus1", "fcs_reflection"], llm + ) + + system_prompt = SYSTEM_PROMPT[args.target_model] + + # Generate conversation format for each variant, which can be used in SimPO/DPO/etc. 
+    fcs_convo = make_preference_conversations(final_dataset, "fcs", system_prompt)
+    fcs_plus1_convo = make_preference_conversations(
+        final_dataset, "fcs_plus1", system_prompt
+    )
+    fcs_reflection_convo = make_preference_conversations(
+        final_dataset, "fcs_reflection", system_prompt
+    )
+
+    # Optionally add short incorrect, long correct (SILC) conversations
+    if args.SILC:
+        short_incorrect_long_correct_conversations = make_SILC_conversations(
+            load_dataset(args.dataset), system_prompt
+        )
+        for convo in [fcs_convo, fcs_plus1_convo, fcs_reflection_convo]:
+            convo += short_incorrect_long_correct_conversations
+            random.shuffle(convo)
+
+    # Save final conversation variants.
+    fcs_outfile = os.path.join(args.result_dir, "fcs-conversations.json")
+    with open(fcs_outfile, "w", encoding="utf-8") as new_file:
+        json.dump(fcs_convo, new_file, ensure_ascii=False, indent=2)
+
+    fcs_plus1_outfile = os.path.join(args.result_dir, "fcs_plus1-conversations.json")
+    with open(fcs_plus1_outfile, "w", encoding="utf-8") as new_file:
+        json.dump(fcs_plus1_convo, new_file, ensure_ascii=False, indent=2)
+
+    fcs_reflection_outfile = os.path.join(
+        args.result_dir, "fcs_reflection-conversations.json"
+    )
+    with open(fcs_reflection_outfile, "w", encoding="utf-8") as new_file:
+        json.dump(fcs_reflection_convo, new_file, ensure_ascii=False, indent=2)
 
 if __name__ == "__main__":
diff --git a/skythought/tools/tasks/aime/aime_handler.py b/skythought/tools/tasks/aime/aime_handler.py
index e69de29..b8c54dd 100644
--- a/skythought/tools/tasks/aime/aime_handler.py
+++ b/skythought/tools/tasks/aime/aime_handler.py
@@ -0,0 +1,40 @@
+from datasets import load_dataset
+
+from tasks.common import MathTaskHandler
+from util.model_utils import MODEL_TO_NAME
+
+
+class AIMETaskHandler(MathTaskHandler):
+    def __init__(self):
+        self.dataset = "AI-MO/aimo-validation-aime"
+
+    @staticmethod
+    def generate_prompt(prompt, model):
+        if MODEL_TO_NAME[model] == "Sky-T1-32B-Preview":
+            return prompt + "\nReturn your final response within \\boxed{{}}"
+        else:
+            return "Return your final response within \\boxed{{}}. 
" + prompt + + @staticmethod + def get_question_key(): + return "problem" + + def make_conversations(self, data, system_prompt, model=None): + conversations = [] + for problem in data: + prompt_text = self.generate_prompt(problem["problem"], model) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) + return conversations + + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): + dataset = load_dataset(self.dataset) + train_data = dataset[split].to_pandas() + filtered_data = train_data[train_data["url"].str.contains("2024", na=False)] + return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] diff --git a/skythought/tools/tasks/math500/math500.yaml b/skythought/tools/tasks/amc23/amc23.yaml similarity index 100% rename from skythought/tools/tasks/math500/math500.yaml rename to skythought/tools/tasks/amc23/amc23.yaml diff --git a/skythought/tools/tasks/amc23/amc23_handler.py b/skythought/tools/tasks/amc23/amc23_handler.py new file mode 100644 index 0000000..fbacdb2 --- /dev/null +++ b/skythought/tools/tasks/amc23/amc23_handler.py @@ -0,0 +1,24 @@ +from datasets import load_dataset +from typing import Dict, Any +from multiprocessing import Manager +from tasks.apps.apps_util import run_test as apps_run_test +from tasks.taco.taco_util import run_test as taco_run_test +from util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer +from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases +from util.common import TimeoutException, timeout +from util.model_utils import SYSTEM_PROMPT, MODEL_TO_NAME +from tasks.common import MathTaskHandler + +class AMC23TaskHandler(MathTaskHandler): + def __init__(self): + self.dataset = "AI-MO/aimo-validation-amc" + + @staticmethod + def get_question_key(): + return "problem" + + def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + dataset = load_dataset(self.dataset) + train_data = dataset[split].to_pandas() + filtered_data = train_data[train_data['url'].str.contains("2023", na=False)] + return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] \ No newline at end of file diff --git a/skythought/tools/tasks/apps/apps_handler.py b/skythought/tools/tasks/apps/apps_handler.py index e69de29..603191c 100644 --- a/skythought/tools/tasks/apps/apps_handler.py +++ b/skythought/tools/tasks/apps/apps_handler.py @@ -0,0 +1,128 @@ +import copy +import json +import multiprocessing +from multiprocessing import Manager + +import numpy as np +from datasets import load_dataset + +from tasks.apps.apps_util import run_test as apps_run_test +from util.common import has_code + +from ..common import TaskHandler + + +class APPSTaskHandler(TaskHandler): + @staticmethod + def get_question_key(): + return "question" + + @staticmethod + def generate_prompt(test_case, prompt, starter_code=None): + _input = "" + data = test_case + if not data.get("fn_name"): + _input += "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition." # "\nUse Standard Input format"#\n" + else: + _input += "Generate an executable Python function generated from the given prompt. 
Return the function body without invoking it at the final solution." # "\nUse Call-Based format"#\n" + data = prompt + _input += data + if starter_code != None: + data = starter_code + data = "\n" + data # + "\n" + _input += data + else: + # _input += "\n\n" + pass + + return _input + + def check_correctness(self, problem, generation): + TIMEOUT = 10 + + def _temp_run(problem, generation, debug, result): + try: + result.append( + apps_run_test(problem=problem, test=generation, debug=debug) + ) + except Exception: + pass + + manager = Manager() + result = manager.list() + p = multiprocessing.Process( + target=_temp_run, args=(problem, generation, False, result) + ) + p.start() + p.join(timeout=TIMEOUT + 1) + if p.is_alive(): + p.kill() + return bool(result and np.all(result[0])) + + def update_results(self, problem, response): + if not isinstance(response, str): + response = response.outputs[0].text.strip() + # Initialize the response structure + response_entry = { + "content": response, + "correctness": None, + "reason": None, + } + code_filter_result = has_code(response) + if len(code_filter_result) == 0: + response_entry["correctness"] = False + response_entry["reason"] = "Does not contain code component." + else: + last_code = code_filter_result[-1] + problem_to_check = copy.deepcopy(problem) + problem_to_check["input_output"] = json.loads(problem["input_output"]) + try: + problem_to_check["solutions"] = json.loads(problem["solutions"]) + except: + problem_to_check["solutions"] = "" + print("Empty solution from the dataset") + curr_res = self.check_correctness(problem_to_check, generation=last_code) + if curr_res: + response_entry["correctness"] = True + response_entry["reason"] = "" + else: + response_entry["correctness"] = False + response_entry["reason"] = "Code is incorrect." 
+ + return response_entry + + def make_conversations(self, data, system_prompt, model=None): + conversations = [] + for problem in data: + test_case = json.loads(problem["input_output"]) + starter_code = problem["starter_code"] + prompt_text = self.generate_prompt( + test_case, problem["question"], starter_code + ) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) + return conversations + + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): + dataset = load_dataset("codeparrot/apps", trust_remote_code=True) + train_data = dataset[split].to_pandas() + if not filter_difficulty: + return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] + return ( + train_data.query("difficulty == @source").iloc[start:end] + if end > 0 + else train_data.query("difficulty == @source").iloc[start:] + ) + + def process_remaining_data(self, train_data, results): + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["question"]) not in results + ] diff --git a/skythought/tools/tasks/apps/apps_util.py b/skythought/tools/tasks/apps/apps_util.py index 7dcb87e..c98dd1f 100644 --- a/skythought/tools/tasks/apps/apps_util.py +++ b/skythought/tools/tasks/apps/apps_util.py @@ -1,43 +1,49 @@ # From APPS import argparse +import faulthandler import json import os -import sys -import io -import faulthandler import platform -# used for debugging to time steps -from datetime import datetime - # to run the solution files we're using a timing based approach import signal +import sys + +# used for debugging to time steps +from datetime import datetime +from enum import Enum -import numpy as np # for capturing the stdout from io import StringIO -from typing import get_type_hints -from typing import List, Tuple +from typing import List + # used for testing the code that reads from input -from unittest.mock import patch, mock_open +from unittest.mock import mock_open, patch +import numpy as np from pyext import RuntimeModule -from enum import Enum + class CODE_TYPE(Enum): call_based = 0 standard_input = 1 + # stuff for setting up signal timer class TimeoutException(Exception): pass + + def timeout_handler(signum, frame): print("alarm went off") - #return + # return raise TimeoutException + + signal.signal(signal.SIGALRM, timeout_handler) timeout = 4 # seconds + # used to capture stdout as a list # from https://stackoverflow.com/a/16571630/6416660 # alternative use redirect_stdout() from contextlib @@ -48,24 +54,39 @@ def __enter__(self): # Make closing the StringIO a no-op self._stringio.close = lambda x: 1 return self + def __exit__(self, *args): self.extend(self._stringio.getvalue().splitlines()) - del self._stringio # free up some memory + del self._stringio # free up some memory sys.stdout = self._stdout def parse_args(): parser = argparse.ArgumentParser(description="Utility for testing code generation.") - parser.add_argument("-v", "--verbosity-level", action="store", type=int, - help="") - parser.add_argument("-s", "--source", type=str, default="leetcode", - choices=["leetcode", "atcoder", "codewars",], - help="which data source to gather from.") - parser.add_argument("-d", "--data", type=str, default="question", - choices=["question", "q", "solutions", "sol", "s", "starter", "tests", "t"], - help="which type of data to receive.") - parser.add_argument("-n", "--number", type=int, default=0, - help="which problem to query.") + 
parser.add_argument("-v", "--verbosity-level", action="store", type=int, help="") + parser.add_argument( + "-s", + "--source", + type=str, + default="leetcode", + choices=[ + "leetcode", + "atcoder", + "codewars", + ], + help="which data source to gather from.", + ) + parser.add_argument( + "-d", + "--data", + type=str, + default="question", + choices=["question", "q", "solutions", "sol", "s", "starter", "tests", "t"], + help="which type of data to receive.", + ) + parser.add_argument( + "-n", "--number", type=int, default=0, help="which problem to query." + ) args = parser.parse_args() return args @@ -90,18 +111,18 @@ def get_valid_problems(data_dir="leetcode"): for folder in tmp: prob_path = os.path.join(root, folder) files = os.listdir(prob_path) - #TODO add more validity checks + # TODO add more validity checks if "input_output.json" in files or "sols.json" in files: valid_probs.append(prob_path) valid_probs = sorted(valid_probs) - #with open(os.path.join(args.source,"valid_problems.json"), "w") as f: + # with open(os.path.join(args.source,"valid_problems.json"), "w") as f: # json.dump(valid_probs, f) return valid_probs def get_question(problem_list, prob_index): root = problem_list[prob_index] - #print("get q", root) + # print("get q", root) if os.path.exists(os.path.join(root, "question.txt")): with open(os.path.join(root, "question.txt")) as f: question = f.readlines() @@ -120,8 +141,13 @@ def get_solutions(problem_list, prob_index): return sols -def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, - test:str=None, debug:bool=False): +def run_test( + problem=None, + problem_list: List[str] = None, + prob_index: int = None, + test: str = None, + debug: bool = False, +): """ if test is not None it'll try to run the code. otherwise it'll just return an input and output pair. @@ -133,11 +159,10 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, if problem_list is not None: root = problem_list[prob_index] - in_outs = problem["input_output"] if debug: print(f"test cases json = {in_outs['inputs']} {in_outs['outputs']}") - + if in_outs.get("fn_name") is None: which_type = CODE_TYPE.standard_input # Standard input method_name = None @@ -146,21 +171,21 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, method_name = in_outs["fn_name"] if debug: print(f"loaded json = {datetime.now().time()}") - + if test is None: return in_outs elif test is not None: # Disable functionalities that can make destructive changes to the test. 
reliability_guard() - + results = [] sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" if debug: print(f"loading test code = {datetime.now().time()}") - + if which_type == CODE_TYPE.call_based: sol += test - if debug: # or True: + if debug: # or True: print(f"sol = {sol}") signal.alarm(timeout) try: @@ -188,7 +213,7 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, else: new_test.append(x + "\n") tmp_test = new_test - + new_test = "" started = False for i in tmp_test: @@ -197,7 +222,7 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, new_test += "def code():\n" new_test += i started = True - elif started and ((i.startswith("from ")) or (i.startswith("import "))): + elif started and ((i.startswith("from ")) or (i.startswith("import "))): new_test += "\t" + i else: new_test += i @@ -206,7 +231,7 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, sol += tmp_test if debug: print(f"sol = {sol}") - # print(f"{o}") + # print(f"{o}") method_name = "code" signal.alarm(timeout) try: @@ -221,7 +246,7 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, signal.alarm(0) if debug: print(f"get method = {datetime.now().time()}") - + try: method = getattr(tmp, method_name) # get_attr second arg must be str except: @@ -235,22 +260,28 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list) try: if isinstance(inputs[0], dict): - inputs = [{int(k): v for k,v in inputs[0].items()}] + inputs = [{int(k): v for k, v in inputs[0].items()}] except: True try: if isinstance(in_outs["outputs"][index], dict): - in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index].items()}] + in_outs["outputs"][index] = [ + {int(k): v for k, v in in_outs["outputs"][index].items()} + ] except: True try: if isinstance(in_outs["outputs"][index][0], dict): - in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index][0].items()}] + in_outs["outputs"][index] = [ + {int(k): v for k, v in in_outs["outputs"][index][0].items()} + ] except: True if debug: - print(f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}") + print( + f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. 
type = {which_type}" + ) if which_type == CODE_TYPE.call_based: # Call-based signal.alarm(timeout) faulthandler.enable() @@ -262,15 +293,23 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, # ground truth sequences are not tuples if isinstance(output, tuple): output = list(output) - + tmp_result = output == in_outs["outputs"][index] - if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]: - tmp_result = tmp_result or (output == in_outs["outputs"][index][0]) + if ( + isinstance(in_outs["outputs"][index], list) + and in_outs["outputs"][index] + ): + tmp_result = tmp_result or ( + output == in_outs["outputs"][index][0] + ) # ground truth sequences are not tuples try: if isinstance(output[0], tuple): - tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0]) + tmp_result = tmp_result or ( + [list(x) for x in output] + == in_outs["outputs"][index][0] + ) except: True results.append(tmp_result) @@ -280,13 +319,17 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, except Exception as e: signal.alarm(0) faulthandler.disable() - print(f"Standard input runtime error or time limit exceeded error = {e}") + print( + f"Standard input runtime error or time limit exceeded error = {e}" + ) results.append(-1) continue faulthandler.disable() signal.alarm(0) if debug: - print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + print( + f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) elif which_type == CODE_TYPE.standard_input: # Standard input faulthandler.enable() signal.alarm(timeout) @@ -294,8 +337,8 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, if isinstance(inputs, list): inputs = "\n".join(inputs) - if isinstance(in_outs['outputs'][index], list): - in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index]) + if isinstance(in_outs["outputs"][index], list): + in_outs["outputs"][index] = "\n".join(in_outs["outputs"][index]) with Capturing() as output: try: @@ -306,7 +349,9 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, except Exception as e: # runtime error or took too long signal.alarm(0) - print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}") + print( + f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}" + ) results.append(-1) signal.alarm(0) @@ -314,15 +359,21 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, if debug: nl = "\n" if not isinstance(inputs, list): - print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + print( + f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) else: - print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + print( + f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) continue if passed and debug: - print(f"==> output = {output}, test outputs = 
{in_outs['outputs'][index]}") + print( + f"==> output = {output}, test outputs = {in_outs['outputs'][index]}" + ) - if custom_compare_(output, in_outs['outputs'][index]): + if custom_compare_(output, in_outs["outputs"][index]): tmp_result = True results.append(tmp_result) continue @@ -333,16 +384,18 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, tmp_result = False try: - tmp_result = (output == [in_outs["outputs"][index]]) + tmp_result = output == [in_outs["outputs"][index]] if isinstance(in_outs["outputs"][index], list): tmp_result = tmp_result or (output == in_outs["outputs"][index]) if isinstance(output[0], str): - tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index]) + tmp_result = tmp_result or ( + [e.strip() for e in output] == in_outs["outputs"][index] + ) except Exception as e: print(f"Failed check1 exception = {e}") pass - if tmp_result == True: + if tmp_result == True: results.append(tmp_result) continue @@ -350,14 +403,20 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, if isinstance(in_outs["outputs"][index], list): for tmp_index, i in enumerate(in_outs["outputs"][index]): in_outs["outputs"][index][tmp_index] = i.split("\n") - in_outs["outputs"][index][tmp_index] = [x.strip() for x in in_outs["outputs"][index][tmp_index] if x] + in_outs["outputs"][index][tmp_index] = [ + x.strip() for x in in_outs["outputs"][index][tmp_index] if x + ] else: in_outs["outputs"][index] = in_outs["outputs"][index].split("\n") - in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index])) - in_outs["outputs"][index] = list(map(lambda x:x.strip(), in_outs["outputs"][index])) + in_outs["outputs"][index] = list( + filter(len, in_outs["outputs"][index]) + ) + in_outs["outputs"][index] = list( + map(lambda x: x.strip(), in_outs["outputs"][index]) + ) try: - tmp_result = (output == [in_outs["outputs"][index]]) + tmp_result = output == [in_outs["outputs"][index]] if isinstance(in_outs["outputs"][index], list): tmp_result = tmp_result or (output == in_outs["outputs"][index]) except Exception as e: @@ -375,16 +434,20 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, if debug: nl = "\n" if not isinstance(inputs, list): - print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + print( + f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) else: - print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") - + print( + f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) + if tmp_result == True: results.append(tmp_result) continue try: - tmp_result = (output == [in_outs["outputs"][index]]) + tmp_result = output == [in_outs["outputs"][index]] if isinstance(in_outs["outputs"][index], list): tmp_result = tmp_result or (output == in_outs["outputs"][index]) except Exception as e: @@ -393,16 +456,22 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, try: output_float = [float(e) for e in output] - gt_float = [float(e) for e in in_outs['outputs'][index]] - tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and 
np.allclose(output_float, gt_float)) - except Exception as e: + gt_float = [float(e) for e in in_outs["outputs"][index]] + tmp_result = tmp_result or ( + (len(output_float) == len(gt_float)) + and np.allclose(output_float, gt_float) + ) + except Exception: pass try: if isinstance(output[0], list): output_float = [float(e) for e in output[0]] - gt_float = [float(e) for e in in_outs['outputs'][index][0]] - tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)) - except Exception as e: + gt_float = [float(e) for e in in_outs["outputs"][index][0]] + tmp_result = tmp_result or ( + (len(output_float) == len(gt_float)) + and np.allclose(output_float, gt_float) + ) + except Exception: pass if tmp_result == True: @@ -417,14 +486,14 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, in_outs["outputs"][index] = set(in_outs["outputs"][index].split()) try: - tmp_result = (output == in_outs["outputs"][index]) + tmp_result = output == in_outs["outputs"][index] except Exception as e: print(f"Failed check4 exception = {e}") continue if tmp_result == True: results.append(tmp_result) - continue + continue # try by converting the output into a split up list too if isinstance(output, list): @@ -432,42 +501,51 @@ def run_test(problem=None, problem_list:List[str]=None, prob_index:int=None, output[tmp_index] = i.split() output = list(filter(len, output)) for tmp_index, i in enumerate(output): - output[tmp_index] = set(i) + output[tmp_index] = set(i) else: output = output.split() output = list(filter(len, output)) output = set(output) try: - tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index])) + tmp_result = set(frozenset(s) for s in output) == set( + frozenset(s) for s in in_outs["outputs"][index] + ) except Exception as e: print(f"Failed check5 exception = {e}") - # if they are all numbers, round so that similar numbers are treated as identical try: - tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\ - set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index])) + tmp_result = tmp_result or ( + set(frozenset(round(float(t), 3) for t in s) for s in output) + == set( + frozenset(round(float(t), 3) for t in s) + for s in in_outs["outputs"][index] + ) + ) except Exception as e: print(f"Failed check6 exception = {e}") - + if tmp_result == True and debug: print("PASSED") - + results.append(tmp_result) - + if debug: nl = "\n" if not isinstance(inputs, list): - print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + print( + f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) else: - print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") - + print( + f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + ) return results + def custom_compare_(output, ground_truth): - if isinstance(output, list): output_1 = "\n".join(output) if stripped_string_compare(output_1, ground_truth): @@ -481,13 +559,14 @@ def custom_compare_(output, ground_truth): return False + def stripped_string_compare(s1, s2): s1 = s1.lstrip().rstrip() s2 
= s2.lstrip().rstrip() return s1 == s2 -def call_method(method, inputs): +def call_method(method, inputs): if isinstance(inputs, list): inputs = "\n".join(inputs) @@ -496,20 +575,22 @@ def call_method(method, inputs): # sys.setrecursionlimit(10000) # @patch('builtins.input', side_effect=inputs.split("\n")) - @patch('builtins.open', mock_open(read_data=inputs)) - @patch('sys.stdin', StringIO(inputs)) - @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator)) - @patch('sys.stdin.readlines', lambda *args: inputs.split("\n")) - @patch('sys.stdin.read', lambda *args: inputs) + @patch("builtins.open", mock_open(read_data=inputs)) + @patch("sys.stdin", StringIO(inputs)) + @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator)) + @patch("sys.stdin.readlines", lambda *args: inputs.split("\n")) + @patch("sys.stdin.read", lambda *args: inputs) # @patch('sys.stdout.write', print) def _inner_call_method(_method): try: return _method() - except SystemExit as e: + except SystemExit: pass finally: pass - return _inner_call_method(method) + + return _inner_call_method(method) + def reliability_guard(maximum_memory_bytes=None): """ @@ -527,10 +608,16 @@ def reliability_guard(maximum_memory_bytes=None): if maximum_memory_bytes is not None: import resource - resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) - resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit( + resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes) + ) + resource.setrlimit( + resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes) + ) if not platform.uname().system == "Darwin": - resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit( + resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes) + ) faulthandler.disable() @@ -591,6 +678,7 @@ def reliability_guard(maximum_memory_bytes=None): sys.modules["psutil"] = None sys.modules["tkinter"] = None + def main(args): print(args) problem_list = sorted(get_valid_problems(args.source)) @@ -602,7 +690,11 @@ def main(args): if args.data == "q" or args.data == "question": tmp = get_question(problem_list, prob_index) print("q", tmp) - elif args.data in ["solutions", "sol", "s",]: + elif args.data in [ + "solutions", + "sol", + "s", + ]: tmp = get_solutions(problem_list, prob_index) print("sol", tmp) elif args.data == "starter": @@ -614,7 +706,10 @@ def main(args): tmp = run_test(problem_list, prob_index, test=sols[0]) print("results = ", tmp) - print("-2 = compile error, -1 is runtime error, False failed test, True passed test") + print( + "-2 = compile error, -1 is runtime error, False failed test, True passed test" + ) + if __name__ == "__main__": args = parse_args() diff --git a/skythought/tools/tasks/common.py b/skythought/tools/tasks/common.py new file mode 100644 index 0000000..0b10d97 --- /dev/null +++ b/skythought/tools/tasks/common.py @@ -0,0 +1,210 @@ +import copy +import json +import multiprocessing +import os +import random +import re +import numpy as np +from datasets import load_dataset +from typing import Dict, Any +from multiprocessing import Manager +from tasks.apps.apps_util import run_test as apps_run_test +from tasks.taco.taco_util import run_test as taco_run_test +from ..util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer +from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, 
map_to_example, has_test_type, post_process_code, translate_private_test_cases
+from ..util.common import TimeoutException, timeout
+from util.model_utils import *
+
+def has_code(response):
+    pattern = r"```(?:[a-zA-Z]*)\n(.*?)```"
+    # Use re.DOTALL to match multiline content inside backticks
+    matches = re.findall(pattern, response, re.DOTALL)
+    # print(matches)
+    return matches
+
+class TaskHandler:
+    @staticmethod
+    def get_question_key():
+        raise NotImplementedError("Subclasses should implement this method.")
+
+    def check_correctness(self, problem, generation):
+        raise NotImplementedError("Subclasses should implement this method.")
+
+    def update_results(self, problem, response):
+        raise NotImplementedError("Subclasses should implement this method.")
+
+    def make_conversations(self, data, system_prompt, model=None):
+        raise NotImplementedError("Subclasses should implement this method.")
+
+    def load_existing_results(self, result_file):
+        if not os.path.exists(result_file):
+            return {}
+        with open(result_file, 'r', encoding='utf-8') as f:
+            records = json.load(f)
+        return records
+
+    def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None):
+        raise NotImplementedError("Subclasses should implement this method.")
+
+    def process_remaining_data(self, train_data, results):
+        raise NotImplementedError("Subclasses should implement this method.")
+
+
+class MathTaskHandler(TaskHandler):
+    @staticmethod
+    def generate_prompt(prompt):
+        return "Return your final response within \\boxed{{}}. " + prompt
+
+    def check_correctness(self, problem, generation):
+        answer = strip_answer_string(problem["answer"])
+        pred = extract_answer(generation)
+        # print(problem)
+        pred = strip_answer_string(pred)
+        return math_equal(pred, answer)
+
+    def update_results(self, problem, response):
+        if not isinstance(response, str):
+            response = response.outputs[0].text.strip()
+        # Initialize the response structure
+        response_entry = {
+            "content": response,
+            "correctness": None,
+            "reason": None,
+        }
+        curr_res = self.check_correctness(problem, generation=response)
+        if curr_res:
+            response_entry["correctness"] = True
+            response_entry["reason"] = ""
+        else:
+            response_entry["correctness"] = False
+            response_entry["reason"] = "Solution is incorrect."
+
+        return response_entry
+
+    def make_conversations(self, data, system_prompt, model=None):
+        conversations = []
+        for problem in data:
+            prompt_text = self.generate_prompt(problem["problem"])
+            conversations.append([
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": prompt_text}
+            ])
+        return conversations
+
+    def process_remaining_data(self, train_data, results):
+        return [row.to_dict() for _, row in train_data.iterrows() if str(row["problem"]) not in results]
+
+    def load_and_filter_dataset(self, start, end, split="test", source=None, filter_difficulty=False, args=None):
+        dataset = load_dataset(self.dataset)
+        train_data = dataset[split].to_pandas()
+        return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:]
+
+
+
+class ARCChallengeTaskHandler(TaskHandler):
+    def __init__(self) -> None:
+        super().__init__()
+        self.dataset = "allenai/ai2_arc"
+        self.ans_re = re.compile(r"[Tt]he best answer is ([A-D])[\.\,]*", re.IGNORECASE)
+        self.letter_re = re.compile(r"([A-D])[\.\,]*")
+        self.canonical_options = ["A", "B", "C", "D"]
+        self.invalid_ans = "[invalid]"
+
+    @staticmethod
+    def get_question_key():
+        return "question"
+
+    @staticmethod
+    def generate_prompt(problem):
+        question = problem["question"]
+        choices = problem["choices"]
+        choices_text = '\n'.join([f"{label}.{choice}" for label, choice in zip(["A", "B", "C", "D"], choices["text"])])
+        full_prompt = "Given the following question and four candidate answers (A, B, C and D), choose the best answer. Your response should end with \"The best answer is [the_answer_letter]\" where [the_answer_letter] is one of the four letter choice (A, B, C, or D).\n" + f"{question}\n{choices_text}"
+        return full_prompt
+
+    def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool:
+        gt_answer = problem["answerKey"]
+        if gt_answer not in self.canonical_options:
+            gt_answer = self.canonical_options[int(problem["answerKey"]) - 1]
+        model_answer = self.get_answer(generation)
+        return model_answer == gt_answer
+
+    def update_results(self, problem, response):
+        if not isinstance(response, str):
+            response = response.outputs[0].text.strip()
+        # Initialize the response structure
+        response_entry = {
+            "content": response,
+            "correctness": None,
+            "reason": None,
+        }
+        curr_res = self.check_correctness(problem, generation=response)
+        if curr_res:
+            response_entry["correctness"] = True
+            response_entry["reason"] = ""
+        else:
+            response_entry["correctness"] = False
+            response_entry["reason"] = "Solution is incorrect."
+
+        return response_entry
+
+    def make_conversations(self, data, system_prompt, model=None):
+        conversations = []
+        for problem in data:
+            prompt_text = self.generate_prompt(problem)
+            conversations.append([
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": prompt_text}
+            ])
+        return conversations
+
+    def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None):
+        dataset = load_dataset(self.dataset, "ARC-Challenge")
+        train_data = dataset[split].to_pandas()
+        return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:]
+
+    def process_remaining_data(self, train_data, results):
+        return [row.to_dict() for _, row in train_data.iterrows() if str(row["question"]) not in results]
+
+    def get_answer(self, completion):
+        # First, we try to extract similar to MATH answers
+        answer = extract_answer(completion)
+        match = None
+        if answer:
+            # match for the letter answer needed.
+            match = self.letter_re.search(answer)
+            if match:
+                return match.group(1).strip()
+
+        if not answer or not match:
+            # try basic-regex based search
+            patterns_to_remove = [
+                ',',  # Remove commas
+                r'\$',  # Remove dollar signs
+                r'\.$'  # Remove trailing period
+                r"\\",  # Remove stray backslashes
+                r"\*",  # Remove asterisks
+            ]
+            answer = completion
+            for pattern in patterns_to_remove:
+                answer = re.sub(pattern, '', answer)
+            matches = self.ans_re.findall(answer)
+            if not matches:
+                return self.invalid_ans
+            return matches[-1].strip()
+
+
+TASK_HANDLERS = {
+    "NUMINA": NUMINATaskHandler,
+    "APPS": APPSTaskHandler,
+    "TACO": TACOTaskHandler,
+    "MATH500": MATH500TaskHandler,
+    "AIME": AIMETaskHandler,
+    "GPQADiamond": GPQADiamondTaskHandler,
+    "MMLU": MMLUTaskHandler,
+    "MMLUPro": MMLUProTaskHandler,
+    "LiveCodeBench": LiveCodeBenchTaskHandler,
+    "GSM8K": GSM8KTaskHandler,
+    "ARC-C": ARCChallengeTaskHandler,
+    "AMC23": AMC23TaskHandler,
+}
diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py
new file mode 100644
index 0000000..a4904fb
--- /dev/null
+++ b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py
@@ -0,0 +1,90 @@
+import copy
+import json
+import multiprocessing
+import os
+import random
+import re
+import numpy as np
+from datasets import load_dataset
+from typing import Dict, Any
+from multiprocessing import Manager
+from util.model_utils import SYSTEM_PROMPT, MODEL_TO_NAME
+from tasks.common import TaskHandler
+from util.math_parsing_util import get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer
+
+class GPQADiamondTaskHandler(TaskHandler):
+    def __init__(self):
+        self.dataset = "Idavidrein/gpqa"
+
+    @staticmethod
+    def generate_prompt(prompt):
+        return "Return your final response within \\boxed{{}} and only include the letter choice (A, B, C, or D) as your final response. " + prompt
+
+    @staticmethod
+    def get_question_key():
+        return "Question"
+
+    def update_results(self, problem, response):
+        if not isinstance(response, str):
+            response = response.outputs[0].text.strip()
+        # Initialize the response structure
+        response_entry = {
+            "content": response,
+            "correctness": None,
+            "reason": None,
+        }
+        curr_res = self.check_correctness(problem, generation=response)
+        if curr_res:
+            response_entry["correctness"] = True
+            response_entry["reason"] = ""
+        else:
+            response_entry["correctness"] = False
+            response_entry["reason"] = "Solution is incorrect."
+ + return response_entry + + def check_correctness(self, problem, generation): + pred = get_multiple_choice_answer(generation) + answer = problem["Answer"] + return answer == pred + + def get_multiple_choice_answers(self, data): + answers = [ + data["Correct Answer"], + data["Incorrect Answer 1"], + data["Incorrect Answer 2"], + data["Incorrect Answer 3"] + ] + random.shuffle(answers) + + # Map options to letters + options = ["A", "B", "C", "D"] + options_to_answers = {letter: answer for letter, answer in zip(options, answers)} + + # Format the options into the string + multiple_choice_string = ", ".join(f"{letter}) {options_to_answers[letter]}" for letter in options) + + # Save the letter corresponding to the correct answer + correct_answer_letter = next(letter for letter, answer in options_to_answers.items() if answer == data["Correct Answer"]) + + return multiple_choice_string, correct_answer_letter + + def make_conversations(self, data, system_prompt, model=None): + conversations = [] + for problem in data: + multiple_choice_string, correct_answer_letter = self.get_multiple_choice_answers(problem) + problem["Answer"] = correct_answer_letter + prompt_text = self.generate_prompt(problem["Question"] + "\n" + multiple_choice_string) + conversations.append([ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text} + ]) + return conversations + + def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + dataset = load_dataset(self.dataset, "gpqa_diamond") + train_data = dataset[split].to_pandas() + return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] + + def process_remaining_data(self, train_data, results): + return [row.to_dict() for _, row in train_data.iterrows() if str(row["Question"]) not in results] \ No newline at end of file diff --git a/skythought/tools/tasks/gsm8k/gsm8k_handler.py b/skythought/tools/tasks/gsm8k/gsm8k_handler.py new file mode 100644 index 0000000..07357cd --- /dev/null +++ b/skythought/tools/tasks/gsm8k/gsm8k_handler.py @@ -0,0 +1,100 @@ +import re +from datasets import load_dataset +from typing import Dict, Any +from multiprocessing import Manager +from tasks.apps.apps_util import run_test as apps_run_test +from tasks.taco.taco_util import run_test as taco_run_test +from util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer +from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases +from util.common import TimeoutException, timeout +from util.model_utils import SYSTEM_PROMPT, MODEL_TO_NAME +from tasks.common import TaskHandler + + +class GSM8KTaskHandler(TaskHandler): + def __init__(self) -> None: + super().__init__() + self.dataset = "openai/gsm8k" + self.ans_re = re.compile(r"((-?[$0-9.,]{2,})|(-?[0-9]+))") + self.gt_re = re.compile(r"#### (\-?[0-9\.\,]+)") + self.invalid_ans = "[invalid]" + + @staticmethod + def get_question_key(): + return "question" + + @staticmethod + def generate_prompt(problem): + question = problem["question"] + full_prompt = f"Given the following problem, reason and give a final answer to the problem.\nProblem: {question}\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem." 
+ return full_prompt + + def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: + gt_answer = self.extract_gt_answer(problem["answer"]) + model_answer = extract_answer(generation) + model_answer = self.sanitize_answer(model_answer) + return model_answer == gt_answer + + def update_results(self, problem, response): + if not isinstance(response, str): + response = response.outputs[0].text.strip() + # Initialize the response structure + response_entry = { + "content": response, + "correctness": None, + "reason": None, + } + curr_res = self.check_correctness(problem, generation=response) + if curr_res: + response_entry["correctness"] = True + response_entry["reason"] = "" + else: + response_entry["correctness"] = False + response_entry["reason"] = "Solution is incorrect." + + return response_entry + + def make_conversations(self, data, system_prompt, model=None): + conversations = [] + for problem in data: + prompt_text = self.generate_prompt(problem) + conversations.append([ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text} + ]) + return conversations + + def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + dataset = load_dataset(self.dataset, "main") + train_data = dataset[split].to_pandas() + return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] + + def process_remaining_data(self, train_data, results): + return [row.to_dict() for _, row in train_data.iterrows() if str(row["question"]) not in results] + + def extract_gt_answer(self, completion): + match = self.gt_re.search(completion) + if match: + match_str = match.group(1).strip() + match_str = match_str.replace(",", "") + return match_str + else: + return self.invalid_ans + + def sanitize_answer(self, answer): + patterns_to_remove = [ + ',', # Remove commas + r'\$', # Remove dollar signs + r'\.$', # Remove trailing period + r"\*", # Remove asterisks + ] + for pattern in patterns_to_remove: + answer = re.sub(pattern, '', answer) + + matches = self.ans_re.findall(answer) + if matches: + # get the last match (i.e. final response) and the first / outer capturing group + match_str = matches[-1][0].strip() + return match_str + else: + return self.invalid_ans \ No newline at end of file diff --git a/skythought/tools/tasks/livecodebench/livecodebench_handler.py b/skythought/tools/tasks/livecodebench/livecodebench_handler.py index e69de29..ee51746 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench_handler.py +++ b/skythought/tools/tasks/livecodebench/livecodebench_handler.py @@ -0,0 +1,137 @@ +import copy +from typing import Dict + +from datasets import load_dataset + +from tasks.common import TaskHandler +from tasks.livecodebench.livecodebench_util import ( + map_to_example, + post_process_code, + translate_private_test_cases, + unsafe_lcb_runTests, +) +from util.common import has_code + + +class LiveCodeBenchTaskHandler(TaskHandler): + @staticmethod + def generate_prompt(problem): + # print(problem) + prompt = problem["prompt"] + if problem["is_stdin"]: + return ( + "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition." + + prompt + ) + else: + return ( + "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution."
+ + prompt + ) + + @staticmethod + def get_question_key(): + return "task_id" + + def check_correctness( + self, + problem: Dict, + completion: str, + timeout: float, + runtime_debug=False, + is_extracted=False, + ) -> Dict: + """ + Evaluates the functional correctness of a completion by running the test + suite provided in the problem. + + :param completion_id: an optional completion ID so we can match + the results later even if execution finishes asynchronously. + """ + result_list = unsafe_lcb_runTests( + problem, completion, timeout, runtime_debug, is_extracted + ) + details = [r[0] for r in result_list] + all_passed = all(details) + + result = "" + if result_list and all_passed: + result = "passed" + + return result == "passed" + + def update_results(self, problem, response): + if not isinstance(response, str): + response = response.outputs[0].text.strip() + # Initialize the response structure + response_entry = { + "content": response, + "correctness": None, + "reason": None, + } + code_filter_result = has_code(response) + # print(response) + if len(code_filter_result) == 0: + response_entry["correctness"] = False + response_entry["reason"] = "Does not contain code component." + else: + last_code = code_filter_result[-1] + problem_to_check = copy.deepcopy(problem) + + curr_res = self.check_correctness( + problem=problem_to_check, + completion=post_process_code(last_code), + timeout=6, + is_extracted=not problem_to_check["is_stdin"], + ) + if curr_res: + response_entry["correctness"] = True + response_entry["reason"] = "" + else: + response_entry["correctness"] = False + response_entry["reason"] = "Code is incorrect." + + return response_entry + + def make_conversations(self, data, system_prompt, model=None): + conversations = [] + for problem in data: + prompt_text = self.generate_prompt(problem) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) + return conversations + + def load_and_filter_dataset( + self, start, end, split="test", source=None, filter_difficulty=False, args=None + ): + dataset = load_dataset( + "livecodebench/code_generation_lite", + version_tag="release_v2", + split=split, + trust_remote_code=True, + ) + if filter_difficulty: + dataset = dataset.filter(lambda example: example["difficulty"] == source) + dataset = dataset.map( + lambda example: { + "private_test_cases": translate_private_test_cases( + example["private_test_cases"] + ) + } + ) + # Apply the mapping function + dataset = dataset.map( + map_to_example, remove_columns=dataset.column_names + ).to_pandas() + return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] + + def process_remaining_data(self, train_data, results): + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["task_id"]) not in results + ] diff --git a/skythought/tools/tasks/livecodebench/livecodebench_util.py b/skythought/tools/tasks/livecodebench/livecodebench_util.py index ab682e3..bc00fb6 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench_util.py +++ b/skythought/tools/tasks/livecodebench/livecodebench_util.py @@ -1,26 +1,23 @@ - -from typing import Optional, Callable, Dict -import ast -import copy +import base64 +import builtins import contextlib +import copy import faulthandler import io -from io import StringIO -import os -import multiprocessing -import platform -import signal -import tempfile import json -import sys -import builtins +import multiprocessing +import os +import pickle import shutil +import 
signal import subprocess +import sys +import tempfile import time -import base64 import zlib -import pickle -import scipy.stats as stats +from io import StringIO +from typing import Optional + def post_process_code(code): code = code.split("")[0] @@ -29,11 +26,12 @@ def post_process_code(code): code = code.replace("", "") # print(f"postprocessed code: {code}") return code - + + def post_process_tests_inputs(raw_text, is_stdin): # raw_text = raw_text.strip().strip("```json").strip("```").strip() # raw_text.strip() # print(raw_text) - if is_stdin: + if is_stdin: blocks = raw_text.split("Input:") formatted_tests = [] @@ -55,7 +53,7 @@ def post_process_tests_inputs(raw_text, is_stdin): "testtype": "stdin", } ) - return formatted_tests + return formatted_tests else: # Step 1: Clean the input string by removing surrounding markdown syntax and extra spaces cleaned_string = raw_text.strip().strip("```json").strip("```").strip() @@ -103,6 +101,7 @@ def post_process_tests_inputs(raw_text, is_stdin): return test_cases + def prepare_test_input_output_functional(test_case, is_extracted): if not is_extracted: # Extract input and expected output from JSON directly @@ -177,6 +176,7 @@ def prepare_test_input_output_functional(test_case, is_extracted): expected_output = expected_output.strip() return inputs, expected_output + def prepare_test_input_output_std(test_case): test_input = test_case["input"] test_output = test_case["output"].strip() @@ -186,6 +186,7 @@ def prepare_test_input_output_std(test_case): ].rstrip() # Remove '-' if present and trailing return test_input, test_output + def run_test_func(completion, is_extracted, test_input, test_output): # print(f"inside: {completion}") if not is_extracted: @@ -261,6 +262,7 @@ def run_test_func(completion, is_extracted, test_input, test_output): return True, result_output + def run_test_std(completion, test_input, test_output): sys.stdin = StringIO(test_input) @@ -269,30 +271,39 @@ def run_test_std(completion, test_input, test_output): if '__name__ == "__main__"' in completion: # Simulate that the code is being run as the main script - completion = f'__name__ = "__main__"\n' + completion - + completion = '__name__ = "__main__"\n' + completion + namespace = {} exec(completion, namespace) output_value = output.getvalue().strip() return output_value == test_output, output_value + def unsafe_lcb_runTests(problem, completion, timeout, runtime_debug, is_extracted): test_cases = problem["test"] manager = multiprocessing.Manager() result = manager.list() - p = multiprocessing.Process(target=run_tests_for_one_example, args=(test_cases, completion, result, runtime_debug, is_extracted)) + p = multiprocessing.Process( + target=run_tests_for_one_example, + args=(test_cases, completion, result, runtime_debug, is_extracted), + ) p.start() - p.join(timeout = (timeout+1) * len(test_cases) + 5) # TODO Alex: Check whether number of task cases is correct + p.join( + timeout=(timeout + 1) * len(test_cases) + 5 + ) # TODO Alex: Check whether number of task cases is correct if p.is_alive(): p.kill() - + # if len(result) < len(test_cases): ## This is supposed to be the case where not all test passed in the given timeout for i in range(len(test_cases) - len(result)): - result.append((False, f"Time out!.", "Error: Time out!", float("inf"))) + result.append((False, "Time out!.", "Error: Time out!", float("inf"))) return result -def run_tests_for_one_example(test_cases, completion, result_list, runtime_debug, is_extracted): + +def run_tests_for_one_example( + test_cases, completion, 
result_list, runtime_debug, is_extracted +): time_elapsed = float("inf") test_type = test_cases[0]["testtype"] reliability_guard() @@ -302,13 +313,20 @@ def run_tests_for_one_example(test_cases, completion, result_list, runtime_debug try: time_start = time.time() if test_type == "functional": - test_input, test_output = prepare_test_input_output_functional(test_case, is_extracted) + test_input, test_output = prepare_test_input_output_functional( + test_case, is_extracted + ) passed, output_value = run_test_func( - completion, is_extracted, copy.deepcopy(test_input), copy.deepcopy(test_output) + completion, + is_extracted, + copy.deepcopy(test_input), + copy.deepcopy(test_output), ) else: test_input, test_output = prepare_test_input_output_std(test_case) - passed, output_value = run_test_std(completion, copy.deepcopy(test_input), copy.deepcopy(test_output)) + passed, output_value = run_test_std( + completion, copy.deepcopy(test_input), copy.deepcopy(test_output) + ) # print(test_input, test_output, output_value) if not passed: output_error = f"For test input: {test_input}. Expected output is: {test_output}, but got: {output_value}." @@ -322,11 +340,12 @@ def run_tests_for_one_example(test_cases, completion, result_list, runtime_debug output_value = f"Error: {e}." if output_error == "": output_error = f"For test input: {test_input}. Expected output is: {test_output}, your solution correctly passes this test with output {output_value}." - + result_list.append((passed, output_error, output_value, time_elapsed)) if not passed: return + @contextlib.contextmanager def time_limit(seconds: float): def signal_handler(signum, frame): @@ -339,6 +358,7 @@ def signal_handler(signum, frame): finally: signal.setitimer(signal.ITIMER_REAL, 0) + @contextlib.contextmanager def swallow_io(redirect_input=True): """ @@ -354,15 +374,18 @@ def swallow_io(redirect_input=True): else: yield stream # Do not redirect stdin + @contextlib.contextmanager def create_tempdir(): with tempfile.TemporaryDirectory() as dirname: with chdir(dirname): yield dirname + class TimeoutException(Exception): pass + class WriteOnlyStringIO(io.StringIO): """StringIO that throws an exception when it's read from""" @@ -379,9 +402,11 @@ def readable(self, *args, **kwargs): """Returns True if the IO object can be read.""" return False + class redirect_stdin(contextlib._RedirectStream): # type: ignore _stream = "stdin" + @contextlib.contextmanager def chdir(root): if root == ".": @@ -539,7 +564,8 @@ def restore_original_references(): for module_name, original_module in originals["sys_modules"].items(): if original_module is not None: sys.modules[module_name] = original_module - + + def has_test_type(tests, type): ## helper to select specific type of problems """ Check if any test in the test list has 'testtype' set to 'type'. 
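For context, the private test cases handled in this file are shipped as a base64-encoded, zlib-compressed pickle of a JSON string, which the translate_private_test_cases hunk just below unpacks. A minimal round-trip sketch; the encode_private_tests helper is hypothetical, added only to illustrate the inverse direction:

    import base64
    import json
    import pickle
    import zlib

    def encode_private_tests(tests):
        # hypothetical inverse, for illustration only: json -> pickle -> zlib -> base64
        return base64.b64encode(zlib.compress(pickle.dumps(json.dumps(tests))))

    def decode_private_tests(blob):
        # mirrors translate_private_test_cases: base64 -> zlib -> pickle -> json
        return json.loads(pickle.loads(zlib.decompress(base64.b64decode(blob))))

    tests = [{"input": "1 2\n", "output": "3\n", "testtype": "stdin"}]
    assert decode_private_tests(encode_private_tests(tests)) == tests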
@@ -550,12 +576,14 @@ def has_test_type(tests, type): ## helper to select specific type of problems return True return False + def translate_private_test_cases(encoded_data): decoded_data = base64.b64decode(encoded_data) decompressed_data = zlib.decompress(decoded_data) original_data = pickle.loads(decompressed_data) return json.loads(original_data) + """ def update_dataset_in_place( dataset, @@ -574,6 +602,7 @@ def update_dataset_in_place( # break """ + def map_to_example(row): return { "prompt": row["question_content"], @@ -583,5 +612,5 @@ def map_to_example(row): "task_id": row["question_id"], "is_stdin": has_test_type(row["public_test_cases"], "stdin"), "public_test_cases": row["public_test_cases"], - "difficulty": row["difficulty"] - } \ No newline at end of file + "difficulty": row["difficulty"], + } diff --git a/skythought/tools/tasks/math500/math500_handler.py b/skythought/tools/tasks/math/math500.yaml similarity index 100% rename from skythought/tools/tasks/math500/math500_handler.py rename to skythought/tools/tasks/math/math500.yaml diff --git a/skythought/tools/tasks/math/math500_handler.py b/skythought/tools/tasks/math/math500_handler.py new file mode 100644 index 0000000..00571bb --- /dev/null +++ b/skythought/tools/tasks/math/math500_handler.py @@ -0,0 +1,10 @@ +from datasets import load_dataset +from typing import Dict, Any +from multiprocessing import Manager +from tasks.apps.apps_util import run_test as apps_run_test +from tasks.taco.taco_util import run_test as taco_run_test +from util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer +from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases +from util.common import TimeoutException, timeout +from util.model_utils import SYSTEM_PROMPT, MODEL_TO_NAME +from tasks.common import MathTaskHandler diff --git a/skythought/tools/tasks/math/math_handler.py b/skythought/tools/tasks/math/math_handler.py new file mode 100644 index 0000000..82f41c1 --- /dev/null +++ b/skythought/tools/tasks/math/math_handler.py @@ -0,0 +1,69 @@ +from datasets import load_dataset +from typing import Dict, Any +from multiprocessing import Manager +from tasks.apps.apps_util import run_test as apps_run_test +from tasks.taco.taco_util import run_test as taco_run_test +from util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer +from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases +from util.common import TimeoutException, timeout +from util.model_utils import SYSTEM_PROMPT, MODEL_TO_NAME +from tasks.common import TaskHandler + + +class MathTaskHandler(TaskHandler): + @staticmethod + def generate_prompt(prompt): + return "Return your final response within \\boxed{{}}. 
" + prompt + + def check_correctness(self, problem, generation): + answer = strip_answer_string(problem["answer"]) + pred = extract_answer(generation) + # print(problem) + pred = strip_answer_string(pred) + return math_equal(pred, answer) + + def update_results(self, problem, response): + if not isinstance(response, str): + response = response.outputs[0].text.strip() + # Initialize the response structure + response_entry = { + "content": response, + "correctness": None, + "reason": None, + } + curr_res = self.check_correctness(problem, generation=response) + if curr_res: + response_entry["correctness"] = True + response_entry["reason"] = "" + else: + response_entry["correctness"] = False + response_entry["reason"] = "Solution is incorrect." + + return response_entry + + def make_conversations(self, data, system_prompt, model=None): + conversations = [] + for problem in data: + prompt_text = self.generate_prompt(problem["problem"]) + conversations.append([ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text} + ]) + return conversations + + def process_remaining_data(self, train_data, results): + return [row.to_dict() for _, row in train_data.iterrows() if str(row["problem"]) not in results] + + def load_and_filter_dataset(self, start, end, split="test", source=None, filter_difficulty=False, args=None): + dataset = load_dataset(self.dataset) + train_data = dataset[split].to_pandas() + return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] + + +class MATH500TaskHandler(MathTaskHandler): + def __init__(self): + self.dataset = "qq8933/MATH500" + + @staticmethod + def get_question_key(): + return "problem" diff --git a/skythought/tools/tasks/mmlu/mmlu_handler.py b/skythought/tools/tasks/mmlu/mmlu_handler.py new file mode 100644 index 0000000..bd293f0 --- /dev/null +++ b/skythought/tools/tasks/mmlu/mmlu_handler.py @@ -0,0 +1,113 @@ +import copy +import json +import multiprocessing +import os +import random +import re +import numpy as np +from datasets import load_dataset +from typing import Dict, Any +from multiprocessing import Manager +from tasks.apps.apps_util import run_test as apps_run_test +from tasks.taco.taco_util import run_test as taco_run_test +from util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer +from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases +from util.common import TimeoutException, timeout +from util.model_utils import SYSTEM_PROMPT + +from ..common import TaskHandler + +class MMLUTaskHandler(TaskHandler): + def __init__(self): + self.dataset = "cais/mmlu" + + @staticmethod + def generate_prompt(prompt): + return "Return your final response within \\boxed{{}}. 
" + prompt + + @staticmethod + def get_question_key(): + return "question" + + def check_correctness(self, problem, generation): + pred = get_multiple_choice_answer(generation) + abcd = "ABCD" + answer = abcd[problem["answer"]] + return answer == pred + + def update_results(self, problem, response): + if not isinstance(response, str): + response = response.outputs[0].text.strip() + # Initialize the response structure + response_entry = { + "content": response, + "correctness": None, + "reason": None, + } + curr_res = self.check_correctness(problem, generation=response) + if curr_res: + response_entry["correctness"] = True + response_entry["reason"] = "" + else: + response_entry["correctness"] = False + response_entry["reason"] = "Solution is incorrect." + return response_entry + + def get_multiple_choice_answers(self, problem): + options = problem["choices"] + for i, (label, option) in enumerate(zip("ABCD", options)): + options[i] = f"({label}) {str(option).strip()}" + options = " ".join(options) + return f"Answer Choices: {options}" + + def make_conversations(self, data, system_prompt, model=None): + conversations = [] + for problem in data: + multiple_choice_string = self.get_multiple_choice_answers(problem) + prompt_text = self.generate_prompt(problem["question"] + "\n" + multiple_choice_string) + conversations.append([ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text} + ]) + return conversations + + def process_remaining_data(self, train_data, results): + return [row.to_dict() for _, row in train_data.iterrows() if str(row["question"]) not in results] + + def load_and_filter_dataset(self, start, end, split="test", source=None, filter_difficulty=False, args=None): + dataset = load_dataset(self.dataset, "all") + train_data = dataset[split].to_pandas() + return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] + + + +class MMLUProTaskHandler(MMLUTaskHandler): + def __init__(self): + super().__init__() + self.dataset = "TIGER-Lab/MMLU-Pro" + self.choices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P"] + + @staticmethod + def generate_prompt(prompt): + return "Return your final response within \\boxed{{}}. 
" + prompt + + @staticmethod + def get_question_key(): + return "question" + + def check_correctness(self, problem, generation): + pred = mmlu_pro_extract_answer(generation) + answer = self.choices[problem["answer_index"]] + return answer == pred + + def get_multiple_choice_answers(self, problem): + options = problem["options"] + for i, (label, option) in enumerate(zip(self.choices[:len(options)], options)): + options[i] = f"({label}) {str(option).strip()}" + options = " ".join(options) + return f"Answer Choices: {options}" + + def load_and_filter_dataset(self, start, end, split="test", source=None, filter_difficulty=False, args=None): + dataset = load_dataset(self.dataset, "default") + train_data = dataset[split].to_pandas() + return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] \ No newline at end of file diff --git a/skythought/tools/tasks/numina/numina_handler.py b/skythought/tools/tasks/numina/numina_handler.py new file mode 100644 index 0000000..a086ae3 --- /dev/null +++ b/skythought/tools/tasks/numina/numina_handler.py @@ -0,0 +1,83 @@ +from datasets import load_dataset +from typing import Dict, Any +from multiprocessing import Manager +from tasks.apps.apps_util import run_test as apps_run_test +from tasks.taco.taco_util import run_test as taco_run_test +from util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer +from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases +from util.common import TimeoutException, timeout +from util.model_utils import SYSTEM_PROMPT, MODEL_TO_NAME +from tasks.common import TaskHandler + +class NUMINATaskHandler(TaskHandler): + @staticmethod + def get_question_key(): + return "problem" + + @staticmethod + def generate_prompt(prompt): + return "Return your final response within \\boxed{{}}. " + prompt + + @timeout(5) # Add timeout of 5 seconds + def check_correctness(self, problem, generation): + solution = extract_answer(problem["solution"]) + solution = strip_answer_string(solution) + pred = extract_answer(generation) + pred = strip_answer_string(pred) + return math_equal(pred, solution) + + def update_results(self, problem, response): + if not isinstance(response, str): + response = response.outputs[0].text.strip() + # Initialize the response structure + response_entry = { + "content": response, + "correctness": None, + "reason": None, + } + + try: + curr_res = self.check_correctness(problem, generation=response) + if curr_res: + response_entry["correctness"] = True + response_entry["reason"] = "" + else: + response_entry["correctness"] = False + response_entry["reason"] = "Solution is incorrect." 
+ except TimeoutException as e: + response_entry["correctness"] = False + response_entry["reason"] = str(e) + + return response_entry + + @staticmethod + def get_difficulty_dict(source, start, end): + diff_dict = {} + dataset = load_dataset("NovaSky-AI/labeled_numina_difficulty_859K", trust_remote_code=True, split="train") + for example in dataset: + # print(example) + diff_dict[example["problem"]] = example["gpt_difficulty_parsed"] + return diff_dict + + def make_conversations(self, data, system_prompt, model=None): + conversations = [] + for problem in data: + prompt_text = self.generate_prompt(problem["problem"]) + conversations.append([ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text} + ]) + return conversations + + def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + dataset = load_dataset("AI-MO/NuminaMath-CoT") + train_data = dataset[split].to_pandas() + train_data = train_data.query('source == @source').iloc[start:end] if end > 0 else train_data.query('source == @source').iloc[start:] + train_data = train_data[train_data["solution"].str.contains("boxed", na=False)] + if filter_difficulty: + diff_dict = self.get_difficulty_dict(source, start, end) + train_data = train_data[train_data["problem"].map(diff_dict).apply(lambda x: x >= args.math_difficulty_lower_bound and x <= args.math_difficulty_upper_bound)] + return train_data + + def process_remaining_data(self, train_data, results): + return [row.to_dict() for _, row in train_data.iterrows() if str(row["problem"]) not in results] \ No newline at end of file diff --git a/skythought/tools/tasks/taco/pyext2.py b/skythought/tools/tasks/taco/pyext2.py index 636b10a..1bff5e4 100644 --- a/skythought/tools/tasks/taco/pyext2.py +++ b/skythought/tools/tasks/taco/pyext2.py @@ -1,4 +1,4 @@ -''' +""" Copyright (C) 2014 Ryan Gonzalez @@ -18,37 +18,60 @@ COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" g_backup = globals().copy() -__version__ = '0.7' - -__all__ = ['overload', 'RuntimeModule', 'switch', 'tail_recurse', 'copyfunc', 'set_docstring', 'annotate', 'safe_unpack', 'modify_function', 'assign', 'fannotate', 'compare_and_swap', 'is_main', 'call_if_main', 'run_main'] - -import sys, inspect, types - -def __targspec(func, specs, attr='__orig_arg__'): - if hasattr(func, '__is_overload__') and func.__is_overload__: +__version__ = "0.7" + +__all__ = [ + "overload", + "RuntimeModule", + "switch", + "tail_recurse", + "copyfunc", + "set_docstring", + "annotate", + "safe_unpack", + "modify_function", + "assign", + "fannotate", + "compare_and_swap", + "is_main", + "call_if_main", + "run_main", +] + +import inspect +import sys +import types + + +def __targspec(func, specs, attr="__orig_arg__"): + if hasattr(func, "__is_overload__") and func.__is_overload__: return getattr(func, attr) return specs(func) + def set_docstring(doc): - '''A simple decorator to set docstrings. + """A simple decorator to set docstrings. + + :param doc: The docstring to tie to the function. - :param doc: The docstring to tie to the function. 
+ Example:: - Example:: + @set_docstring('This is a docstring') + def myfunc(x): + pass""" - @set_docstring('This is a docstring') - def myfunc(x): - pass''' def _wrap(f): f.__doc__ = doc return f + return _wrap -__modify_function_doc = ''' + +__modify_function_doc = """ Creates a copy of a function, changing its attributes. :param globals: Will be added to the function's globals. @@ -62,63 +85,100 @@ def _wrap(f): :param closure: The new function closure. Set to ``None`` to use the function's original closure. .. warning:: This function can be potentially dangerous. -''' +""" + def copyfunc(f): - '''Copies a funcion. + """Copies a funcion. + + :param f: The function to copy. - :param f: The function to copy. + :return: The copied function. - :return: The copied function. + .. deprecated:: 0.4 + Use :func:`modify_function` instead. + """ + return modify_function(f) - .. deprecated:: 0.4 - Use :func:`modify_function` instead. - ''' - return modify_function(f) if sys.version_info.major == 3: + @set_docstring(__modify_function_doc) - def modify_function(f, globals={}, name=None, code=None, defaults=None, - closure=None): - if code is None: code = f.__code__ - if name is None: name = f.__name__ - if defaults is None: defaults = f.__defaults__ - if closure is None: closure = f.__closure__ - newf = types.FunctionType(code, dict(f.__globals__, **globals), name=name, - argdefs=defaults, closure=closure) + def modify_function( + f, globals={}, name=None, code=None, defaults=None, closure=None + ): + if code is None: + code = f.__code__ + if name is None: + name = f.__name__ + if defaults is None: + defaults = f.__defaults__ + if closure is None: + closure = f.__closure__ + newf = types.FunctionType( + code, + dict(f.__globals__, **globals), + name=name, + argdefs=defaults, + closure=closure, + ) newf.__dict__.update(f.__dict__) return newf + def argspec(f): return inspect.getfullargspec(f) + ofullargspec = inspect.getfullargspec + def _fullargspec(func): return __targspec(func, ofullargspec) + inspect.getfullargspec = _fullargspec - def _exec(m,g): exec(m,g) + + def _exec(m, g): + exec(m, g) + else: + @set_docstring(__modify_function_doc) - def modify_function(f, globals={}, name=None, code=None, defaults=None, - closure=None): - if code is None: code = f.func_code - if name is None: name = f.__name__ - if defaults is None: defaults = f.func_defaults - if closure is None: closure = f.func_closure - newf = types.FunctionType(code, dict(f.func_globals, **globals), name=name, - argdefs=defaults, closure=closure) + def modify_function( + f, globals={}, name=None, code=None, defaults=None, closure=None + ): + if code is None: + code = f.func_code + if name is None: + name = f.__name__ + if defaults is None: + defaults = f.func_defaults + if closure is None: + closure = f.func_closure + newf = types.FunctionType( + code, + dict(f.func_globals, **globals), + name=name, + argdefs=defaults, + closure=closure, + ) newf.__dict__.update(f.__dict__) return newf + def argspec(f): return inspect.getargspec(f) - eval(compile('def _exec(m,g): exec m in g', '', 'exec')) + + eval(compile("def _exec(m,g): exec m in g", "", "exec")) + def _gettypes(args): return tuple(map(type, args)) + oargspec = inspect.getargspec + def _argspec(func): return __targspec(func, oargspec) + inspect.getargspec = _argspec try: @@ -128,50 +188,60 @@ def _argspec(func): else: # Replace IPython's argspec oipyargspec = IPython.core.oinspect.getargspec + def _ipyargspec(func): - return __targspec(func, oipyargspec, '__orig_arg_ipy__') + return 
__targspec(func, oipyargspec, "__orig_arg_ipy__") + IPython.core.oinspect.getargspec = _ipyargspec + class overload(object): - '''Simple function overloading in Python.''' + """Simple function overloading in Python.""" + _items = {} _types = {} + @classmethod def argc(self, argc=None): - '''Overloads a function based on the specified argument count. + """Overloads a function based on the specified argument count. - :param argc: The argument count. Defaults to ``None``. If ``None`` is given, automatically compute the argument count from the given function. + :param argc: The argument count. Defaults to ``None``. If ``None`` is given, automatically compute the argument count from the given function. - .. note:: + .. note:: - Keyword argument counts are NOT checked! In addition, when the argument count is automatically calculated, the keyword argument count is also ignored! + Keyword argument counts are NOT checked! In addition, when the argument count is automatically calculated, the keyword argument count is also ignored! - Example:: + Example:: - @overload.argc() - def func(a): - print 'Function 1 called' + @overload.argc() + def func(a): + print 'Function 1 called' - @overload.argc() - def func(a, b): - print 'Function 2 called' + @overload.argc() + def func(a, b): + print 'Function 2 called' - func(1) # Calls first function - func(1, 2) # Calls second function - func() # Raises error - ''' + func(1) # Calls first function + func(1, 2) # Calls second function + func() # Raises error + """ # Python 2 UnboundLocalError fix - argc = {'argc': argc} + argc = {"argc": argc} + def _wrap(f): def _newf(*args, **kwargs): if len(args) not in self._items[f.__name__]: - raise TypeError("No overload of function '%s' that takes %d args" % (f.__name__, len(args))) + raise TypeError( + "No overload of function '%s' that takes %d args" + % (f.__name__, len(args)) + ) return self._items[f.__name__][len(args)](*args, **kwargs) + if f.__name__ not in self._items: self._items[f.__name__] = {} - if argc['argc'] is None: - argc['argc'] = len(argspec(f).args) - self._items[f.__name__][argc['argc']] = f + if argc["argc"] is None: + argc["argc"] = len(argspec(f).args) + self._items[f.__name__][argc["argc"]] = f _newf.__name__ = f.__name__ _newf.__doc__ = f.__doc__ _newf.__is_overload__ = True @@ -179,54 +249,68 @@ def _newf(*args, **kwargs): if IPython: _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f) return _newf + return _wrap + @classmethod def args(self, *argtypes, **kw): - '''Overload a function based on the specified argument types. + """Overload a function based on the specified argument types. - :param argtypes: The argument types. If None is given, get the argument types from the function annotations(Python 3 only) - :param kw: Can only contain 1 argument, `is_cls`. If True, the function is assumed to be part of a class. + :param argtypes: The argument types. If None is given, get the argument types from the function annotations(Python 3 only) + :param kw: Can only contain 1 argument, `is_cls`. If True, the function is assumed to be part of a class. 
- Example:: + Example:: - @overload.args(str) - def func(s): - print 'Got string' + @overload.args(str) + def func(s): + print 'Got string' - @overload.args(int, str) - def func(i, s): - print 'Got int and string' + @overload.args(int, str) + def func(i, s): + print 'Got int and string' - @overload.args() - def func(i:int): # A function annotation example - print 'Got int' + @overload.args() + def func(i:int): # A function annotation example + print 'Got int' - func('s') - func(1) - func(1, 's') - func(True) # Raises error - ''' + func('s') + func(1) + func(1, 's') + func(True) # Raises error + """ # Python 2 UnboundLocalError fix...again! - argtypes = {'args': tuple(argtypes)} + argtypes = {"args": tuple(argtypes)} + def _wrap(f): def _newf(*args): if len(kw) == 0: cargs = args - elif len(kw) == 1 and 'is_cls' in kw and kw['is_cls']: + elif len(kw) == 1 and "is_cls" in kw and kw["is_cls"]: cargs = args[1:] else: - raise ValueError('Invalid keyword args specified') + raise ValueError("Invalid keyword args specified") if _gettypes(cargs) not in self._types[f.__name__]: - raise TypeError("No overload of function '%s' that takes '%s' types and %d arg(s)" % (f.__name__, _gettypes(cargs), len(cargs))) + raise TypeError( + "No overload of function '%s' that takes '%s' types and %d arg(s)" + % (f.__name__, _gettypes(cargs), len(cargs)) + ) return self._types[f.__name__][_gettypes(cargs)](*args) + if f.__name__ not in self._types: self._types[f.__name__] = {} - if len(argtypes['args']) == 1 and argtypes['args'][0] is None: + if len(argtypes["args"]) == 1 and argtypes["args"][0] is None: aspec = argspec(f) - argtypes['args'] = tuple(map(lambda x: x[1], sorted( - aspec.annotations.items(), key=lambda x: aspec.args.index(x[0])))) - self._types[f.__name__][argtypes['args']] = f + argtypes["args"] = tuple( + map( + lambda x: x[1], + sorted( + aspec.annotations.items(), + key=lambda x: aspec.args.index(x[0]), + ), + ) + ) + self._types[f.__name__][argtypes["args"]] = f _newf.__name__ = f.__name__ _newf.__doc__ = f.__doc__ _newf.__is_overload__ = True @@ -234,119 +318,146 @@ def _newf(*args): if IPython: _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f) return _newf + return _wrap + class _RuntimeModule(object): - 'Create a module object at runtime and insert it into sys.path. If called, same as :py:func:`from_objects`.' + "Create a module object at runtime and insert it into sys.path. If called, same as :py:func:`from_objects`." + def __call__(self, *args, **kwargs): return self.from_objects(*args, **kwargs) + @staticmethod @overload.argc(1) def from_objects(module_name_for_code_eval, **d): - return _RuntimeModule.from_objects(module_name_for_code_eval, '', **d) + return _RuntimeModule.from_objects(module_name_for_code_eval, "", **d) + @staticmethod @overload.argc(2) def from_objects(module_name_for_code_eval, docstring, **d): - '''Create a module at runtime from `d`. + """Create a module at runtime from `d`. - :param name: The module name. + :param name: The module name. - :param docstring: Optional. The module's docstring. + :param docstring: Optional. The module's docstring. - :param \*\*d: All the keyword args, mapped from name->value. + :param \*\*d: All the keyword args, mapped from name->value. 
- Example: ``RuntimeModule.from_objects('name', 'doc', a=1, b=2)``''' + Example: ``RuntimeModule.from_objects('name', 'doc', a=1, b=2)``""" module = types.ModuleType(module_name_for_code_eval, docstring) module.__dict__.update(d) - module.__file__ = '' + module.__file__ = "" sys.modules[module_name_for_code_eval] = module return module + @staticmethod @overload.argc(2) def from_string(module_name_for_code_eval, s): - return _RuntimeModule.from_string(module_name_for_code_eval, '', s) + return _RuntimeModule.from_string(module_name_for_code_eval, "", s) + @staticmethod @overload.argc(3) def from_string(module_name_for_code_eval, docstring, s): - '''Create a module at runtime from `s``. + """Create a module at runtime from `s``. - :param name: The module name. + :param name: The module name. - :param docstring: Optional. The module docstring. + :param docstring: Optional. The module docstring. - :param s: A string containing the module definition.''' + :param s: A string containing the module definition.""" g = {} _exec(s, g) - return _RuntimeModule.from_objects(module_name_for_code_eval, docstring, **dict(filter(lambda x: x[0] not in g_backup, g.items()))) + return _RuntimeModule.from_objects( + module_name_for_code_eval, + docstring, + **dict(filter(lambda x: x[0] not in g_backup, g.items())) + ) + RuntimeModule = _RuntimeModule() + class CaseObject(object): - 'The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. It can be called with multiple parameters, in which case it checks if its value equals any of the arguments.' + "The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. It can be called with multiple parameters, in which case it checks if its value equals any of the arguments." + def __init__(self, value): self.value = value self.did_match = False self.did_pass = False + def __call__(self, *args): - if assign('res', not self.did_pass and any([self.value == rhs for rhs in args])): + if assign( + "res", not self.did_pass and any([self.value == rhs for rhs in args]) + ): self.did_match = True return res + def quit(self): - 'Forces all other calls to return False. Equilavent of a ``break`` statement.' + "Forces all other calls to return False. Equilavent of a ``break`` statement." self.did_pass = True + def default(self): "Executed if quit wasn't called." return not self.did_match and not self.did_pass + def __iter__(self): yield self + def __enter__(self): return self + def __exit__(self, *args): pass + def switch(value): - '''A Python switch statement implementation that is used with a ``with`` statement. + """A Python switch statement implementation that is used with a ``with`` statement. - :param value: The value to "switch". + :param value: The value to "switch". - ``with`` statement example:: + ``with`` statement example:: - with switch('x'): - if case(1): print 'Huh?' - if case('x'): print 'It works!!!' + with switch('x'): + if case(1): print 'Huh?' + if case('x'): print 'It works!!!' - .. warning:: If you modify a variable named "case" in the same scope that you use the ``with`` statement version, you will get an UnboundLocalError. The soluction is to use ``with switch('x') as case:`` instead of ``with switch('x'):``.''' + .. warning:: If you modify a variable named "case" in the same scope that you use the ``with`` statement version, you will get an UnboundLocalError. 
The soluction is to use ``with switch('x') as case:`` instead of ``with switch('x'):``. + """ res = CaseObject(value) - inspect.stack()[1][0].f_globals['case'] = res + inspect.stack()[1][0].f_globals["case"] = res return res + def tail_recurse(spec=None): - '''Remove tail recursion from a function. + """Remove tail recursion from a function. - :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. If ``None,`` tail recursion is always called unless the function returns a value. + :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. If ``None,`` tail recursion is always called unless the function returns a value. - .. note:: + .. note:: - This function has a slight overhead that is noticable when using timeit. Only use it if the function has a possibility of going over the recursion limit. + This function has a slight overhead that is noticable when using timeit. Only use it if the function has a possibility of going over the recursion limit. - .. warning:: + .. warning:: - This function will BREAK any code that either uses any recursion other than tail recursion or calls itself multiple times. For example, ``def x(): return x()+1`` will fail. + This function will BREAK any code that either uses any recursion other than tail recursion or calls itself multiple times. For example, ``def x(): return x()+1`` will fail. - Example:: + Example:: - @tail_recurse() - def add(a, b): - if a == 0: return b - return add(a-1, b+1) + @tail_recurse() + def add(a, b): + if a == 0: return b + return add(a-1, b+1) + + add(10000000, 1) # Doesn't max the recursion limit. + """ - add(10000000, 1) # Doesn't max the recursion limit. - ''' def _wrap(f): class TailRecursion(Exception): def __init__(self, args, kwargs): self.args = args self.kwargs = kwargs + def _newf(*args, **kwargs): if inspect.stack()[1][3] == f.__name__: if (spec and spec(args)) or not spec: @@ -360,122 +471,137 @@ def _newf(*args, **kwargs): continue else: return res + _newf.__doc__ = f.__doc__ return _newf + return _wrap + def annotate(*args, **kwargs): - '''Set function annotations using decorators. + """Set function annotations using decorators. - :param args: This is a list of annotations for the function, in the order of the function's parameters. For example, ``annotate('Annotation 1', 'Annotation 2')`` will set the annotations of parameter 1 of the function to ``Annotation 1``. + :param args: This is a list of annotations for the function, in the order of the function's parameters. For example, ``annotate('Annotation 1', 'Annotation 2')`` will set the annotations of parameter 1 of the function to ``Annotation 1``. - :param kwargs: This is a mapping of argument names to annotations. Note that these are applied *after* the argument list, so any args set that way will be overriden by this mapping. If there is a key named `ret`, that will be the annotation for the function's return value. + :param kwargs: This is a mapping of argument names to annotations. Note that these are applied *after* the argument list, so any args set that way will be overriden by this mapping. If there is a key named `ret`, that will be the annotation for the function's return value. + + .. deprecated:: 0.5 + Use :func:`fannotate` instead.""" - .. deprecated:: 0.5 - Use :func:`fannotate` instead. 
-''' def _wrap(f): - if not hasattr(f, '__annotations__'): + if not hasattr(f, "__annotations__"): f.__annotations__ = {} - if 'ret' in kwargs: - f.__annotations__['return'] = kwargs.pop('ret') + if "ret" in kwargs: + f.__annotations__["return"] = kwargs.pop("ret") f.__annotations__.update(dict(zip(argspec(f).args, args))) f.__annotations__.update(kwargs) return f + return _wrap + def fannotate(*args, **kwargs): - '''Set function annotations using decorators. + """Set function annotations using decorators. + + :param \*args: The first positional argument is used for the function's return value; all others are discarded. - :param \*args: The first positional argument is used for the function's return value; all others are discarded. + :param \**kwargs: This is a mapping of argument names to annotations. - :param \**kwargs: This is a mapping of argument names to annotations. + Example:: - Example:: + @fannotate('This for the return value', a='Parameter a', b='Parameter b') + def x(a, b): + pass - @fannotate('This for the return value', a='Parameter a', b='Parameter b') - def x(a, b): - pass + """ - ''' def _wrap(f): - if not hasattr(f, '__annotations__'): + if not hasattr(f, "__annotations__"): f.__annotations__ = {} if len(args) >= 1: - f.__annotations__['return'] = args[0] + f.__annotations__["return"] = args[0] f.__annotations__.update(kwargs) return f + return _wrap + def safe_unpack(seq, ln, fill=None): - '''Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. Empty values will be filled in with `fill`, while any extra values will be cut off. + """Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. Empty values will be filled in with `fill`, while any extra values will be cut off. - :param seq: The sequence to unpack. + :param seq: The sequence to unpack. - :param ln: The expected length of the sequence. + :param ln: The expected length of the sequence. - :param fill: The value to substitute if the sequence is too small. Defaults to ``None``. + :param fill: The value to substitute if the sequence is too small. Defaults to ``None``. - Example:: + Example:: - s = 'a:b' - a, b = safe_unpack(s.split(':'), 2) - # a = 'a' - # b = 'b' - s = 'a' - a, b = safe_unpack(s.split(':'), 2) - # a = 'a' - # b = None''' + s = 'a:b' + a, b = safe_unpack(s.split(':'), 2) + # a = 'a' + # b = 'b' + s = 'a' + a, b = safe_unpack(s.split(':'), 2) + # a = 'a' + # b = None""" if len(seq) > ln: return seq[:ln] elif len(seq) < ln: - return seq + type(seq)([fill]*(ln-len(seq))) + return seq + type(seq)([fill] * (ln - len(seq))) else: return seq + def assign(varname, value): - '''Assign `value` to `varname` and return it. If `varname` is an attribute and the instance name it belongs to is not defined, a NameError is raised. - This can be used to emulate assignment as an expression. For example, this:: + """Assign `value` to `varname` and return it. If `varname` is an attribute and the instance name it belongs to is not defined, a NameError is raised. + This can be used to emulate assignment as an expression. For example, this:: - if assign('x', 7): ... + if assign('x', 7): ... - is equilavent to this C code:: + is equilavent to this C code:: - if (x = 7) ... + if (x = 7) ... - .. warning:: + .. warning:: - When assigning an attribute, the instance it belongs to MUST be declared as global prior to the assignment. Otherwise, the assignment will not work. 
- ''' + When assigning an attribute, the instance it belongs to MUST be declared as global prior to the assignment. Otherwise, the assignment will not work. + """ fd = inspect.stack()[1][0].f_globals - if '.' not in varname: + if "." not in varname: fd[varname] = value else: - vsplit = list(map(str.strip, varname.split('.'))) + vsplit = list(map(str.strip, varname.split("."))) if vsplit[0] not in fd: - raise NameError('Unknown object: %s'%vsplit[0]) + raise NameError("Unknown object: %s" % vsplit[0]) base = fd[vsplit[0]] for x in vsplit[1:-1]: base = getattr(base, x) setattr(base, vsplit[-1], value) return value + def is_main(frame=1): "Return if the caller is main. Equilavent to ``__name__ == '__main__'``." - return inspect.stack()[frame][0].f_globals['__name__'] == '__main__' + return inspect.stack()[frame][0].f_globals["__name__"] == "__main__" + def _call_if_main(frame, f, args): - if is_main(frame): return f(*args) + if is_main(frame): + return f(*args) + -def call_if_main(f,*args): +def call_if_main(f, *args): "Call the `f` with `args` if the caller's module is main." - return _call_if_main(3,f,args) + return _call_if_main(3, f, args) -def run_main(f,*args): + +def run_main(f, *args): "Call `f` with the `args` and terminate the program with its return code if the caller's module is main." - sys.exit(_call_if_main(3,f,args)) + sys.exit(_call_if_main(3, f, args)) + def compare_and_swap(var, compare, new): "If `var` is equal to `compare`, set it to `new`." - if assign('v', inspect.stack()[1][0].f_globals)[var] == compare: + if assign("v", inspect.stack()[1][0].f_globals)[var] == compare: v[var] = new diff --git a/skythought/tools/tasks/taco/taco_handler.py b/skythought/tools/tasks/taco/taco_handler.py index e69de29..917207d 100644 --- a/skythought/tools/tasks/taco/taco_handler.py +++ b/skythought/tools/tasks/taco/taco_handler.py @@ -0,0 +1,124 @@ +import json +import multiprocessing +from multiprocessing import Manager + +import numpy as np +from datasets import load_dataset + +from tasks.taco.taco_util import run_test as taco_run_test +from util.common import has_code + +from ..common import TaskHandler + + +class TACOTaskHandler(TaskHandler): + @staticmethod + def get_question_key(): + return "question" + + @staticmethod + def generate_prompt(prompt, starter_code=None, fn_name=None): + _input = "\nQUESTION:\n" + _input += prompt + if starter_code: + _input += starter_code + if (not fn_name) and (not starter_code): + call_format = "\nUse Standard Input format" + _input += call_format + else: + call_format = "\nUse Call-Based format" + _input += call_format + _input += "\nANSWER:\n" + + return _input + + def check_correctness(self, problem, generation): + TIME_OUT = 300 + + def _temp_run(problem, generation, debug, result): + try: + result.append(taco_run_test(problem, test=generation, debug=debug)) + except Exception as e: + print(f"Error in _temp_run: {e}") + + manager = Manager() + result = manager.list() + p = multiprocessing.Process( + target=_temp_run, args=(problem, generation, False, result) + ) + p.start() + p.join(timeout=TIME_OUT + 1) + if p.is_alive(): + p.kill() + return bool(result and np.all(result[0])) + + def update_results(self, problem, response): + if not isinstance(response, str): + response = response.outputs[0].text.strip() + # Initialize the response structure + response_entry = { + "content": response, + "correctness": None, + "reason": None, + } + code_filter_result = has_code(response) + if len(code_filter_result) == 0: + response_entry["correctness"] = False 
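+ # An empty result from has_code means the response contained no extractable code block, so the attempt is marked incorrect without running the TACO tests.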
+ response_entry["reason"] = "Does not contain code component." + else: + last_code = code_filter_result[-1] + curr_res = self.check_correctness(problem, generation=last_code) + if curr_res: + response_entry["correctness"] = True + response_entry["reason"] = "" + else: + response_entry["correctness"] = False + response_entry["reason"] = "Code is incorrect." + + return response_entry + + def make_conversations(self, data, system_prompt, model=None): + conversations = [] + for idx, problem in enumerate(data): + starter_code = ( + None if len(problem["starter_code"]) == 0 else problem["starter_code"] + ) + try: + input_outpout = json.loads(problem["input_output"]) + fn_name = ( + None + if not input_outpout.get("fn_name") + else input_outpout["fn_name"] + ) + except ValueError: + fn_name = None + prompt_text = self.generate_prompt( + problem["question"], starter_code, fn_name + ) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) + return conversations + + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): + dataset = load_dataset("BAAI/TACO", "ALL", trust_remote_code=True) + train_data = dataset[split].to_pandas() + if not filter_difficulty: + return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] + return ( + train_data.query("difficulty == @source").iloc[start:end] + if end > 0 + else train_data.query("difficulty == @source").iloc[start:] + ) + + def process_remaining_data(self, train_data, results): + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["question"]) not in results + ] diff --git a/skythought/tools/tasks/taco/taco_util.py b/skythought/tools/tasks/taco/taco_util.py index 25715dd..68148b5 100644 --- a/skythought/tools/tasks/taco/taco_util.py +++ b/skythought/tools/tasks/taco/taco_util.py @@ -1,26 +1,29 @@ # modifed from https://github.com/hendrycks/apps/blob/main/eval/testing_util.py to fix some evaluation bugs and add instructions -from .pyext2 import RuntimeModule +import faulthandler +import json +import os import signal -import numpy as np +import subprocess +import sys +import tempfile # used for debugging to time steps from datetime import datetime - -import os, sys, json -import faulthandler - -import subprocess -import tempfile -import inspect from enum import Enum -from unittest.mock import patch, mock_open from io import StringIO +from unittest.mock import mock_open, patch + +import numpy as np + +from .pyext2 import RuntimeModule + class CODE_TYPE(Enum): call_based = 0 standard_input = 1 + class Capturing(list): def __enter__(self): self._stdout = sys.stdout @@ -28,31 +31,44 @@ def __enter__(self): # Make closing the StringIO a no-op self._stringio.close = lambda x: 1 return self + def __exit__(self, *args): self.extend(self._stringio.getvalue().splitlines()) - del self._stringio # free up some memory + del self._stringio # free up some memory sys.stdout = self._stdout + # to run the solution files we're using a timing based approach -import signal # stuff for setting up signal timer class TimeoutException(Exception): pass + + def timeout_handler(signum, frame): - print(f"alarm went off") + print("alarm went off") # return raise TimeoutException + + signal.signal(signal.SIGALRM, timeout_handler) TIMEOUT = 4 # seconds -EXECUTION_RESULTS = {1: "passed", 0: "false", -1: "timeout", -2: "runtime_error", -3: "returncode:{code}", -4: "compile_error"} +EXECUTION_RESULTS = { + 1: "passed", + 
0: "false", + -1: "timeout", + -2: "runtime_error", + -3: "returncode:{code}", + -4: "compile_error", +} + def run_test(sample, test=None, debug=False): """ if test(generated_code) is not None it'll try to run the code. otherwise it'll just return an input and output pair. """ - + if debug: print(f"start = {datetime.now().time()}") @@ -60,7 +76,7 @@ def run_test(sample, test=None, debug=False): in_outs = json.loads(sample["input_output"]) except ValueError: in_outs = None - + if in_outs: if in_outs.get("fn_name") is None: which_type = CODE_TYPE.standard_input # Standard input @@ -87,27 +103,58 @@ def run_test(sample, test=None, debug=False): print(f"loading test code = {datetime.now().time()}") if which_type == CODE_TYPE.call_based: synthesized_code = synthesize_cb_code(test, debug) - method_func = compile_and_get_func(synthesized_code, which_type, method_name, timeout=TIMEOUT, debug=debug) + method_func = compile_and_get_func( + synthesized_code, which_type, method_name, timeout=TIMEOUT, debug=debug + ) elif which_type == CODE_TYPE.standard_input: synthesized_code, exec_code = synthesize_std_code(test, debug) - method_func = compile_and_get_func(synthesized_code, which_type, method_name, timeout=TIMEOUT, debug=debug) + method_func = compile_and_get_func( + synthesized_code, which_type, method_name, timeout=TIMEOUT, debug=debug + ) if not method_func: results.append(-2) return results else: if which_type == CODE_TYPE.call_based: # Call-based - detail_results, debug_infos = execute_cb_code(method_func, inputs_list, outputs_list, timeout=TIMEOUT, early_stop=True, debug=debug) + detail_results, debug_infos = execute_cb_code( + method_func, + inputs_list, + outputs_list, + timeout=TIMEOUT, + early_stop=True, + debug=debug, + ) elif which_type == CODE_TYPE.standard_input: - detail_results = execute_std_code(method_func, exec_code, inputs_list, outputs_list, timeout=TIMEOUT, early_stop=True, debug=debug) - debug_infos = detail_results.get('debug', None) - detail_results = {k:v for k, v in detail_results.items() if k!='debug'} - if set(detail_results.values()) == {(False, 'returncode:1')}: - detail_results = execute_std_code(method_func, synthesized_code+'\ncode()\n', inputs_list, outputs_list, timeout=TIMEOUT, early_stop=True, debug=debug) - + detail_results = execute_std_code( + method_func, + exec_code, + inputs_list, + outputs_list, + timeout=TIMEOUT, + early_stop=True, + debug=debug, + ) + debug_infos = detail_results.get("debug", None) + detail_results = { + k: v for k, v in detail_results.items() if k != "debug" + } + if set(detail_results.values()) == {(False, "returncode:1")}: + detail_results = execute_std_code( + method_func, + synthesized_code + "\ncode()\n", + inputs_list, + outputs_list, + timeout=TIMEOUT, + early_stop=True, + debug=debug, + ) + if isinstance(detail_results, list): if len(detail_results) == 1: detail_results = detail_results * len(inputs_list) - detail_results = dict(zip([i for i in range(len(inputs_list))], detail_results)) + detail_results = dict( + zip([i for i in range(len(inputs_list))], detail_results) + ) for test_id, test_result in detail_results.items(): if test_result[1] == "passed": results.append(True) @@ -119,28 +166,30 @@ def run_test(sample, test=None, debug=False): results.append(-3) return results + def process_input_output(inputs, outputs): # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list) try: if isinstance(inputs[0], dict): - inputs = [{int(k): v for k,v in inputs[0].items()}] + inputs = [{int(k): v for 
k, v in inputs[0].items()}] except: True - + try: if isinstance(outputs, dict): - outputs = [{int(k): v for k,v in outputs.items()}] + outputs = [{int(k): v for k, v in outputs.items()}] except: True try: if isinstance(outputs[0], dict): - outputs = [{int(k): v for k,v in outputs[0].items()}] + outputs = [{int(k): v for k, v in outputs[0].items()}] except: True - + return inputs, outputs + def compile_and_get_func(program, which_type, method_name, timeout, debug): try: signal.alarm(timeout) @@ -155,12 +204,12 @@ def compile_and_get_func(program, which_type, method_name, timeout, debug): if debug: print(f"compilation error = {e}") return False - + if which_type == CODE_TYPE.call_based: assert isinstance(method_name, str) else: method_name = "code" - + try: signal.alarm(timeout) method = getattr(tmp, method_name) # get_attr second arg must be str @@ -173,6 +222,7 @@ def compile_and_get_func(program, which_type, method_name, timeout, debug): return False return method + def synthesize_cb_code(raw_code, debug=False): sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" if debug: @@ -180,31 +230,33 @@ def synthesize_cb_code(raw_code, debug=False): sol += raw_code return sol + def synthesize_std_code(raw_code, debug=False): normal_import_lines = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" if debug: print(f"loading test code = {datetime.now().time()}") - - sol = "" # code for compile - sol2 = "" # code for execute + + sol = "" # code for compile + sol2 = "" # code for execute tmp_test = raw_code.split("\n") # define the code line type, 1 for import lines, 2 for import * lines with indent, 0 for normal codes - code_types = [] - + code_types = [] for x in tmp_test: - if 'import *' in x: + if "import *" in x: code_types.append(2) elif x.startswith("from ") or x.startswith("import "): - code_types.append(1) + code_types.append(1) else: code_types.append(0) - + started = False - special_import_lines = [i.lstrip('\t') for idx, i in enumerate(tmp_test) if code_types[idx]==2] - special_import_lines = '\n'.join(special_import_lines) + special_import_lines = [ + i.lstrip("\t") for idx, i in enumerate(tmp_test) if code_types[idx] == 2 + ] + special_import_lines = "\n".join(special_import_lines) for idx, i in enumerate(tmp_test): code_type = code_types[idx] @@ -223,17 +275,17 @@ def synthesize_std_code(raw_code, debug=False): sol2 += f"{i}\n" if code_type < 2: if started: - sol += '\t' + sol += "\t" sol += f"{i}\n" - + if debug: print(f"sol = {sol}") print(f"sol2 = {sol2}") - + return sol, sol2 -def call_method(method, inputs): +def call_method(method, inputs): if isinstance(inputs, list): inputs = "\n".join(inputs) @@ -242,22 +294,26 @@ def call_method(method, inputs): # sys.setrecursionlimit(10000) 
# @patch('builtins.input', side_effect=inputs.split("\n")) - @patch('builtins.open', mock_open(read_data=inputs)) - @patch('sys.stdin', StringIO(inputs)) - @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator)) - @patch('sys.stdin.readlines', lambda *args: inputs.split("\n")) - @patch('sys.stdin.read', lambda *args: inputs) + @patch("builtins.open", mock_open(read_data=inputs)) + @patch("sys.stdin", StringIO(inputs)) + @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator)) + @patch("sys.stdin.readlines", lambda *args: inputs.split("\n")) + @patch("sys.stdin.read", lambda *args: inputs) # @patch('sys.stdout.write', print) def _inner_call_method(_method): try: return _method() - except SystemExit as e: + except SystemExit: pass finally: pass - return _inner_call_method(method) -def execute_cb_code(method, inputs_list, outputs_list, timeout, early_stop=True, debug=True): + return _inner_call_method(method) + + +def execute_cb_code( + method, inputs_list, outputs_list, timeout, early_stop=True, debug=True +): # Disable functionalities that can make destructive changes to the test. reliability_guard() results = [] @@ -267,7 +323,7 @@ def execute_cb_code(method, inputs_list, outputs_list, timeout, early_stop=True, debug_infos[index] = {} outputs = outputs_list[index] try: - signal.alarm(timeout) + signal.alarm(timeout) faulthandler.enable() exec_outputs = method(*inputs) signal.alarm(0) @@ -287,7 +343,7 @@ def execute_cb_code(method, inputs_list, outputs_list, timeout, early_stop=True, # ground truth sequences are not tuples if isinstance(exec_outputs, tuple): exec_outputs = list(exec_outputs) - + tmp_result = exec_outputs == outputs if isinstance(outputs, list) and outputs: tmp_result = tmp_result or (exec_outputs == outputs[0]) @@ -310,21 +366,33 @@ def execute_cb_code(method, inputs_list, outputs_list, timeout, early_stop=True, continue if debug: - print(f"outputs = {exec_outputs}, test outputs = {outputs}, inputs = {inputs}, {type(inputs)}, {tmp_result}") + print( + f"outputs = {exec_outputs}, test outputs = {outputs}, inputs = {inputs}, {type(inputs)}, {tmp_result}" + ) debug_infos[index] = { - 'inputs': inputs, - 'gt_outputs': outputs, - 'exec_outputs': exec_outputs - } + "inputs": inputs, + "gt_outputs": outputs, + "exec_outputs": exec_outputs, + } return results, debug_infos + def remove_tmp_files(): - tmp_files = ['input.txt', 'output.txt'] + tmp_files = ["input.txt", "output.txt"] for tmp_file in tmp_files: - if tmp_file in os.listdir('.'): + if tmp_file in os.listdir("."): os.remove(tmp_file) -def execute_std_code(method, synthesized_code, inputs_list, outputs_list, timeout, early_stop=False, debug=False): + +def execute_std_code( + method, + synthesized_code, + inputs_list, + outputs_list, + timeout, + early_stop=False, + debug=False, +): temp_program_path = create_temp_file(synthesized_code) if debug: print("Test program:", temp_program_path) @@ -332,7 +400,7 @@ def execute_std_code(method, synthesized_code, inputs_list, outputs_list, timeou assert len(inputs_list) == len(outputs_list) exec_results = {} if debug: - exec_results['debug'] = {} + exec_results["debug"] = {} for i, inputs in enumerate(inputs_list): remove_tmp_files() outputs = outputs_list[i] @@ -340,9 +408,15 @@ def execute_std_code(method, synthesized_code, inputs_list, outputs_list, timeou inputs = "\n".join(inputs) if isinstance(outputs, list): outputs = "\n".join(outputs) - + try: - result = subprocess.run(['python', temp_program_path], input=inputs, text=True, capture_output=True, 
timeout=timeout) + result = subprocess.run( + ["python", temp_program_path], + input=inputs, + text=True, + capture_output=True, + timeout=timeout, + ) exec_code = 999 except subprocess.TimeoutExpired: exec_code = -1 @@ -371,7 +445,7 @@ def execute_std_code(method, synthesized_code, inputs_list, outputs_list, timeou # exec_code = 1 # else: # exec_code = 0 - + # except: # exec_code = -3 if compare_std_results(result.stdout, outputs, debug): @@ -379,65 +453,83 @@ def execute_std_code(method, synthesized_code, inputs_list, outputs_list, timeou else: exec_code = 0 assert exec_code != -3 - exec_results[i] = (exec_code==1, EXECUTION_RESULTS[exec_code] if exec_code>-3 else EXECUTION_RESULTS[exec_code].format(result.returncode)) + exec_results[i] = ( + exec_code == 1, + EXECUTION_RESULTS[exec_code] + if exec_code > -3 + else EXECUTION_RESULTS[exec_code].format(result.returncode), + ) if exec_code >= 0: if debug: - print_debug_info(inputs=inputs, outputs=outputs, exec_outputs=result.stdout) - exec_results['debug'][i] = { - 'inputs': inputs, - 'gt_outputs': outputs, - 'exec_outputs': result.stdout + print_debug_info( + inputs=inputs, outputs=outputs, exec_outputs=result.stdout + ) + exec_results["debug"][i] = { + "inputs": inputs, + "gt_outputs": outputs, + "exec_outputs": result.stdout, } - if early_stop and exec_code<=0: + if early_stop and exec_code <= 0: break return exec_results + def print_debug_info(inputs, outputs, exec_outputs): nl = "\n" if not isinstance(inputs, list): - print(f"exec output = {exec_outputs}, test outputs = {outputs}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {exec_outputs == [outputs]}") + print( + f"exec output = {exec_outputs}, test outputs = {outputs}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {exec_outputs == [outputs]}" + ) else: - print(f"exec output = {exec_outputs}, test outputs = {outputs}, inputs = {inputs}, {type(inputs)}, {exec_outputs == [outputs]}") + print( + f"exec output = {exec_outputs}, test outputs = {outputs}, inputs = {inputs}, {type(inputs)}, {exec_outputs == [outputs]}" + ) + def create_temp_file(content): - with tempfile.NamedTemporaryFile(delete=False, mode='w', encoding='utf-8') as temp_file: + with tempfile.NamedTemporaryFile( + delete=False, mode="w", encoding="utf-8" + ) as temp_file: temp_file.write(content) temp_file_path = temp_file.name return temp_file_path + def compare_std_results(exec_outputs, outputs, debug=False): if stripped_string_compare(exec_outputs, outputs): return True - + if isinstance(exec_outputs, list): output_1 = "\n".join(exec_outputs) if stripped_string_compare(output_1, outputs): return True - + if isinstance(exec_outputs, list): output_2 = [o.lstrip().rstrip() for o in exec_outputs] output_2 = "\n".join(output_2) if stripped_string_compare(output_2, outputs): return True - + tmp_result = False # ground truth sequences are expressed as lists not tuples if isinstance(outputs, tuple): outputs = list(outputs) - + try: - tmp_result = (exec_outputs == [outputs]) + tmp_result = exec_outputs == [outputs] if isinstance(outputs, list): tmp_result = tmp_result or (exec_outputs == outputs) if isinstance(exec_outputs[0], str): - tmp_result = tmp_result or ([e.strip() for e in exec_outputs] == outputs) + tmp_result = tmp_result or ( + [e.strip() for e in exec_outputs] == outputs + ) except Exception as e: if debug: print(f"Failed check1 exception = {e}") pass if tmp_result: return True - + # try one more time without \n if isinstance(outputs, list): for tmp_index, i in enumerate(outputs): @@ 
-446,10 +538,10 @@ def compare_std_results(exec_outputs, outputs, debug=False): else: outputs = outputs.split("\n") outputs = list(filter(len, outputs)) - outputs = list(map(lambda x:x.strip(), outputs)) - + outputs = list(map(lambda x: x.strip(), outputs)) + try: - tmp_result = (exec_outputs == [outputs]) + tmp_result = exec_outputs == [outputs] if isinstance(outputs, list): tmp_result = tmp_result or (exec_outputs == outputs) except Exception as e: @@ -458,12 +550,12 @@ def compare_std_results(exec_outputs, outputs, debug=False): pass if tmp_result: return True - + # try by converting the output into a split up list too if isinstance(exec_outputs, list): exec_outputs = list(filter(len, exec_outputs)) try: - tmp_result = (exec_outputs == [outputs]) + tmp_result = exec_outputs == [outputs] if isinstance(outputs, list): tmp_result = tmp_result or (exec_outputs == outputs) except Exception as e: @@ -472,25 +564,28 @@ def compare_std_results(exec_outputs, outputs, debug=False): pass if tmp_result: return True - try: output_float = [float(e) for e in exec_outputs] gt_float = [float(e) for e in outputs] - tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)) - except Exception as e: + tmp_result = tmp_result or ( + (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float) + ) + except Exception: pass try: if isinstance(exec_outputs[0], list): output_float = [float(e) for e in exec_outputs[0]] gt_float = [float(e) for e in outputs[0]] - tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)) - except Exception as e: + tmp_result = tmp_result or ( + (len(output_float) == len(gt_float)) + and np.allclose(output_float, gt_float) + ) + except Exception: pass if tmp_result: return True - if isinstance(outputs, list): for tmp_index, i in enumerate(outputs): outputs[tmp_index] = set(i.split()) @@ -498,48 +593,54 @@ def compare_std_results(exec_outputs, outputs, debug=False): outputs = set(outputs.split()) try: - tmp_result = (exec_outputs == outputs) + tmp_result = exec_outputs == outputs except Exception as e: if debug: print(f"Failed check4 exception = {e}") if tmp_result: return True - + # try by converting the output into a split up list too if isinstance(exec_outputs, list): for tmp_index, i in enumerate(exec_outputs): exec_outputs[tmp_index] = i.split() exec_outputs = list(filter(len, exec_outputs)) for tmp_index, i in enumerate(exec_outputs): - exec_outputs[tmp_index] = set(i) + exec_outputs[tmp_index] = set(i) else: exec_outputs = exec_outputs.split() exec_outputs = list(filter(len, exec_outputs)) exec_outputs = set(exec_outputs) try: - tmp_result = (set(frozenset(s) for s in exec_outputs) == set(frozenset(s) for s in outputs)) + tmp_result = set(frozenset(s) for s in exec_outputs) == set( + frozenset(s) for s in outputs + ) except Exception as e: if debug: print(f"Failed check5 exception = {e}") - + # if they are all numbers, round so that similar numbers are treated as identical try: - tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in exec_outputs) ==\ - set(frozenset(round(float(t),3) for t in s) for s in outputs)) + tmp_result = tmp_result or ( + set(frozenset(round(float(t), 3) for t in s) for s in exec_outputs) + == set(frozenset(round(float(t), 3) for t in s) for s in outputs) + ) except Exception as e: if debug: print(f"Failed check6 exception = {e}") if tmp_result: return True - + return False + def stripped_string_compare(s1, s2): s1 = 
s1.lstrip().rstrip() s2 = s2.lstrip().rstrip() return s1 == s2 + def reliability_guard(maximum_memory_bytes=None): """ This disables various destructive functions and prevents the generated code @@ -555,10 +656,16 @@ def reliability_guard(maximum_memory_bytes=None): if maximum_memory_bytes is not None: import resource - resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) - resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit( + resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes) + ) + resource.setrlimit( + resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes) + ) if not platform.uname().system == "Darwin": - resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit( + resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes) + ) faulthandler.disable() diff --git a/skythought/tools/upload_hub.py b/skythought/tools/upload_hub.py index c318a7f..8b4d052 100644 --- a/skythought/tools/upload_hub.py +++ b/skythought/tools/upload_hub.py @@ -9,7 +9,8 @@ import tempfile import torch -from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer + def upload_hub(model_path, hub_repo_id, component, private): if component == "all": @@ -42,4 +43,4 @@ def upload_hub(model_path, hub_repo_id, component, private): parser.add_argument("--private", action="store_true") args = parser.parse_args() - upload_hub(args.model_path, args.hub_repo_id, args.component, args.private) \ No newline at end of file + upload_hub(args.model_path, args.hub_repo_id, args.component, args.private) diff --git a/skythought/tools/util/common.py b/skythought/tools/util/common.py index efd2b36..1094cdf 100644 --- a/skythought/tools/util/common.py +++ b/skythought/tools/util/common.py @@ -1,11 +1,15 @@ import multiprocessing + class TimeoutException(Exception): """Custom exception for function timeout.""" + pass + def timeout(seconds): """Decorator to enforce a timeout on a function using multiprocessing.""" + def decorator(func): def wrapper(*args, **kwargs): # A queue to store the result or exception @@ -18,19 +22,25 @@ def target(queue, *args, **kwargs): except Exception as e: queue.put((False, e)) - process = multiprocessing.Process(target=target, args=(queue, *args), kwargs=kwargs) + process = multiprocessing.Process( + target=target, args=(queue, *args), kwargs=kwargs + ) process.start() process.join(seconds) if process.is_alive(): process.terminate() process.join() - raise TimeoutException(f"Function '{func.__name__}' timed out after {seconds} seconds!") + raise TimeoutException( + f"Function '{func.__name__}' timed out after {seconds} seconds!" 
+ ) success, value = queue.get() if success: return value else: raise value + return wrapper - return decorator \ No newline at end of file + + return decorator diff --git a/skythought/tools/util/math_parsing_util.py b/skythought/tools/util/math_parsing_util.py index 4a503b7..798930f 100644 --- a/skythought/tools/util/math_parsing_util.py +++ b/skythought/tools/util/math_parsing_util.py @@ -3,15 +3,15 @@ """ import re -import regex -from word2number import w2n from math import isclose -from collections import defaultdict -from sympy import simplify, N -from sympy.parsing.sympy_parser import parse_expr -from sympy.parsing.latex import parse_latex +import regex from latex2sympy2 import latex2sympy +from sympy import N, simplify +from sympy.parsing.latex import parse_latex +from sympy.parsing.sympy_parser import parse_expr +from word2number import w2n + def convert_word_number(text: str) -> str: try: @@ -20,6 +20,7 @@ def convert_word_number(text: str) -> str: pass return text + def _fix_fracs(string): substrs = string.split("\\frac") new_str = substrs[0] @@ -73,6 +74,7 @@ def _fix_sqrt(string): _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) return _string + def strip_answer_string(string): string = str(string).strip() # linebreaks @@ -111,9 +113,10 @@ def strip_answer_string(string): def replace_match(match): word = match.group(1).lower() if convert_word_number(word) == word: - return match.group(0) + return match.group(0) else: - return convert_word_number(word) + return convert_word_number(word) + string = re.sub(r"\\text\{([a-zA-Z]+)\}", replace_match, string) # Before removing unit, check if the unit is squared (for surface area) @@ -223,18 +226,19 @@ def replace_match(match): if re.fullmatch(r"(\s*-?\d+\s*,)*\s*-?\d+\s*", string): # Split the string into a list of integers try: - integer_list = list(map(int, string.split(','))) + integer_list = list(map(int, string.split(","))) except: - integer_list = list(map(int, "-1,-1".split(','))) + integer_list = list(map(int, "-1,-1".split(","))) # Sort the list in ascending order sorted_list = sorted(integer_list) # Join the sorted list back into a comma-separated string - string = ','.join(map(str, sorted_list)) + string = ",".join(map(str, sorted_list)) return string + def extract_answer(pred_str, use_last_number=True): pred_str = pred_str.replace("\u043a\u0438", "") if "final answer is $" in pred_str and "$. I hope" in pred_str: @@ -292,6 +296,7 @@ def extract_answer(pred_str, use_last_number=True): pred = strip_answer_string(pred) return pred + def get_multiple_choice_answer(pred: str): tmp = re.findall(r"\b(A|B|C|D)\b", pred.upper()) if tmp: @@ -302,13 +307,14 @@ def get_multiple_choice_answer(pred: str): if len(pred) == 0: pred = "" else: - pred = pred[-1] + pred = pred[-1] # Remove the period at the end, again! pred = pred.rstrip(".").rstrip("/") return pred + def mmlu_pro_extract_answer(text): pattern = r"answer is \(?([A-J])\)?" 
match = re.search(pattern, text) @@ -316,7 +322,7 @@ def mmlu_pro_extract_answer(text): return match.group(1) else: # print("1st answer extract failed\n" + text) - match = re.search(r'.*[aA]nswer:\s*([A-J])', text) + match = re.search(r".*[aA]nswer:\s*([A-J])", text) if match: return match.group(1) else: @@ -326,6 +332,7 @@ def mmlu_pro_extract_answer(text): if match: return match.group(0) + def choice_answer_clean(pred: str): pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") # Clean the answer based on the dataset @@ -427,7 +434,7 @@ def math_equal( prediction = str(prediction).strip() ## pmatrix (amps) - if "pmatrix" in prediction and not "pmatrix" in reference: + if "pmatrix" in prediction and "pmatrix" not in reference: reference = str_to_pmatrix(reference) ## deal with [], (), {} @@ -614,4 +621,4 @@ def _parse(s): except: pass - return False \ No newline at end of file + return False diff --git a/skythought/tools/util/model_utils.py b/skythought/tools/util/model_utils.py index fc46014..066d00a 100644 --- a/skythought/tools/util/model_utils.py +++ b/skythought/tools/util/model_utils.py @@ -64,7 +64,7 @@ "PRIME-RL/Eurus-2-7B-PRIME": "Eurus-2-7B-PRIME", "NovaSky-AI/Sky-T1-32B-Preview": "Sky-T1-32B-Preview", "openai/o1-mini": "o1-mini", - "openai/o1-preview": "o1-preview", + "openai/o1-preview": "o1-preview", "openai/gpt-4o-mini": "gpt-4o-mini", } @@ -127,4 +127,4 @@ Your response should be: True -""" \ No newline at end of file +""" diff --git a/skythought/tools/util/prompts.py b/skythought/tools/util/prompts.py index c530bee..c815db4 100644 --- a/skythought/tools/util/prompts.py +++ b/skythought/tools/util/prompts.py @@ -56,7 +56,8 @@ Make sure you include: <|begin_of_slow_thought|>, <|end_of_slow_thought|>, <|begin_of_solution|>,<|end_of_solution|> These four headers explicitly. \ Content to be converted: {content}" -convert_prompt_example = ("<|begin_of_thought|>\n\n" +convert_prompt_example = ( + "<|begin_of_thought|>\n\n" "Okay, so I've got this problem here. Mr. Wang leaves home at 6 AM, riding his bike at 12 km/h, " "and he stops to rest for 6 minutes after every 30 minutes of riding. 
Then, when he arrives at a park " "that's 16.8 km away, I need to find out the angle between the hour and minute hands on his watch.\n\n" @@ -106,7 +107,8 @@ "$$\\text{Angle} = |30H - 5.5M|$$\n\n" " - At 7:36, $H = 7$ and $M = 36$:\n\n" "$$\\text{Angle} = |30 \\times 7 - 5.5 \\times 36| = |210 - 198| = 12 \\text{ degrees}$$\n\n" - "Thus, the angle between the hour and minute hands on his watch is $\\boxed{12}$.<|end_of_solution|>\n") + "Thus, the angle between the hour and minute hands on his watch is $\\boxed{12}$.<|end_of_solution|>\n" +) # From https://arxiv.org/pdf/2412.09413 system_prompt = "Your role as an assistant involves thoroughly exploring questions through a systematic long \ @@ -127,4 +129,4 @@ <|begin_of_solution|> \ {final formatted, precise, and clear solution} \ <|end_of_solution|> \ -Now, try to solve the following question through the above guidelines:" \ No newline at end of file +Now, try to solve the following question through the above guidelines:" diff --git a/skythought/tools/util/task_handlers.py b/skythought/tools/util/task_handlers.py index 4f33eb5..3cb6895 100644 --- a/skythought/tools/util/task_handlers.py +++ b/skythought/tools/util/task_handlers.py @@ -4,17 +4,32 @@ import os import random import re +from multiprocessing import Manager +from typing import Any, Dict + import numpy as np from datasets import load_dataset -from typing import Dict, Any -from multiprocessing import Manager + from tasks.apps.apps_util import run_test as apps_run_test +from tasks.livecodebench.livecodebench_util import ( + map_to_example, + post_process_code, + translate_private_test_cases, + unsafe_lcb_runTests, +) from tasks.taco.taco_util import run_test as taco_run_test -from .math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer -from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases -from .common import TimeoutException, timeout from util.model_utils import * +from .common import TimeoutException, timeout +from .math_parsing_util import ( + extract_answer, + get_multiple_choice_answer, + math_equal, + mmlu_pro_extract_answer, + strip_answer_string, +) + + def has_code(response): pattern = r"```(?:[a-zA-Z]*)\n(.*?)```" # Use re.DOTALL to match multiline content inside backticks @@ -22,6 +37,7 @@ def has_code(response): # print(matches) return matches + class TaskHandler: @staticmethod def get_question_key(): @@ -29,7 +45,7 @@ def get_question_key(): def check_correctness(self, problem, generation): raise NotImplementedError("Subclasses should implement this method.") - + def update_results(self, problem, response): raise NotImplementedError("Subclasses should implement this method.") @@ -39,28 +55,31 @@ def make_conversations(self, data, system_prompt, model=None): def load_existing_results(self, result_file): if not os.path.exists(result_file): return {} - with open(result_file, 'r', encoding='utf-8') as f: + with open(result_file, "r", encoding="utf-8") as f: records = json.load(f) return records - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): raise NotImplementedError("Subclasses should implement this method.") def process_remaining_data(self, train_data, results): raise NotImplementedError("Subclasses should implement this 
method.") - + + class MathTaskHandler(TaskHandler): @staticmethod def generate_prompt(prompt): return "Return your final response within \\boxed{{}}. " + prompt - + def check_correctness(self, problem, generation): answer = strip_answer_string(problem["answer"]) pred = extract_answer(generation) # print(problem) pred = strip_answer_string(pred) return math_equal(pred, answer) - + def update_results(self, problem, response): if not isinstance(response, str): response = response.outputs[0].text.strip() @@ -77,73 +96,91 @@ def update_results(self, problem, response): else: response_entry["correctness"] = False response_entry["reason"] = "Solution is incorrect." - + return response_entry - + def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: prompt_text = self.generate_prompt(problem["problem"]) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - + def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["problem"]) not in results] + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["problem"]) not in results + ] - def load_and_filter_dataset(self, start, end, split="test", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="test", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset) train_data = dataset[split].to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] + class MATH500TaskHandler(MathTaskHandler): def __init__(self): self.dataset = "qq8933/MATH500" - + @staticmethod def get_question_key(): return "problem" + class AIMETaskHandler(MathTaskHandler): def __init__(self): self.dataset = "AI-MO/aimo-validation-aime" - + @staticmethod def generate_prompt(prompt, model): if MODEL_TO_NAME[model] == "Sky-T1-32B-Preview": return prompt + "\nReturn your final response within \\boxed{{}}" else: return "Return your final response within \\boxed{{}}. 
" + prompt - + @staticmethod def get_question_key(): return "problem" - + def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: prompt_text = self.generate_prompt(problem["problem"], model) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset) train_data = dataset[split].to_pandas() - filtered_data = train_data[train_data['url'].str.contains("2024", na=False)] + filtered_data = train_data[train_data["url"].str.contains("2024", na=False)] return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] - + + class GPQADiamondTaskHandler(TaskHandler): def __init__(self): self.dataset = "Idavidrein/gpqa" @staticmethod def generate_prompt(prompt): - return "Return your final response within \\boxed{{}} and only include the letter choice (A, B, C, or D) as your final response. " + prompt + return ( + "Return your final response within \\boxed{{}} and only include the letter choice (A, B, C, or D) as your final response. " + + prompt + ) @staticmethod def get_question_key(): @@ -165,54 +202,76 @@ def update_results(self, problem, response): else: response_entry["correctness"] = False response_entry["reason"] = "Solution is incorrect." - + return response_entry - + def check_correctness(self, problem, generation): pred = get_multiple_choice_answer(generation) answer = problem["Answer"] return answer == pred - + def get_multiple_choice_answers(self, data): answers = [ data["Correct Answer"], data["Incorrect Answer 1"], data["Incorrect Answer 2"], - data["Incorrect Answer 3"] + data["Incorrect Answer 3"], ] random.shuffle(answers) # Map options to letters options = ["A", "B", "C", "D"] - options_to_answers = {letter: answer for letter, answer in zip(options, answers)} + options_to_answers = { + letter: answer for letter, answer in zip(options, answers) + } # Format the options into the string - multiple_choice_string = ", ".join(f"{letter}) {options_to_answers[letter]}" for letter in options) + multiple_choice_string = ", ".join( + f"{letter}) {options_to_answers[letter]}" for letter in options + ) # Save the letter corresponding to the correct answer - correct_answer_letter = next(letter for letter, answer in options_to_answers.items() if answer == data["Correct Answer"]) + correct_answer_letter = next( + letter + for letter, answer in options_to_answers.items() + if answer == data["Correct Answer"] + ) return multiple_choice_string, correct_answer_letter - + def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: - multiple_choice_string, correct_answer_letter = self.get_multiple_choice_answers(problem) + ( + multiple_choice_string, + correct_answer_letter, + ) = self.get_multiple_choice_answers(problem) problem["Answer"] = correct_answer_letter - prompt_text = self.generate_prompt(problem["Question"] + "\n" + multiple_choice_string) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + prompt_text = self.generate_prompt( + 
problem["Question"] + "\n" + multiple_choice_string + ) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset, "gpqa_diamond") train_data = dataset[split].to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["Question"]) not in results] + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["Question"]) not in results + ] + class MMLUTaskHandler(TaskHandler): def __init__(self): @@ -249,38 +308,66 @@ def update_results(self, problem, response): response_entry["correctness"] = False response_entry["reason"] = "Solution is incorrect." return response_entry - + def get_multiple_choice_answers(self, problem): options = problem["choices"] for i, (label, option) in enumerate(zip("ABCD", options)): options[i] = f"({label}) {str(option).strip()}" options = " ".join(options) return f"Answer Choices: {options}" - + def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: multiple_choice_string = self.get_multiple_choice_answers(problem) - prompt_text = self.generate_prompt(problem["question"] + "\n" + multiple_choice_string) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + prompt_text = self.generate_prompt( + problem["question"] + "\n" + multiple_choice_string + ) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["question"]) not in results] + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["question"]) not in results + ] - def load_and_filter_dataset(self, start, end, split="test", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="test", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset, "all") train_data = dataset[split].to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] + class MMLUProTaskHandler(MMLUTaskHandler): def __init__(self): super().__init__() self.dataset = "TIGER-Lab/MMLU-Pro" - self.choices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P"] + self.choices = [ + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "J", + "K", + "L", + "M", + "N", + "O", + "P", + ] @staticmethod def generate_prompt(prompt): @@ -297,16 +384,19 @@ def check_correctness(self, problem, generation): def get_multiple_choice_answers(self, problem): options = problem["options"] - for i, (label, option) in enumerate(zip(self.choices[:len(options)], options)): + for i, (label, option) in enumerate(zip(self.choices[: len(options)], options)): options[i] = f"({label}) {str(option).strip()}" options = " ".join(options) return f"Answer Choices: {options}" - def load_and_filter_dataset(self, start, end, split="test", source=None, 
filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="test", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset, "default") train_data = dataset[split].to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - + + class NUMINATaskHandler(TaskHandler): @staticmethod def get_question_key(): @@ -315,7 +405,7 @@ def get_question_key(): @staticmethod def generate_prompt(prompt): return "Return your final response within \\boxed{{}}. " + prompt - + @timeout(5) # Add timeout of 5 seconds def check_correctness(self, problem, generation): solution = extract_answer(problem["solution"]) @@ -323,7 +413,7 @@ def check_correctness(self, problem, generation): pred = extract_answer(generation) pred = strip_answer_string(pred) return math_equal(pred, solution) - + def update_results(self, problem, response): if not isinstance(response, str): response = response.outputs[0].text.strip() @@ -351,7 +441,11 @@ def update_results(self, problem, response): @staticmethod def get_difficulty_dict(source, start, end): diff_dict = {} - dataset = load_dataset("NovaSky-AI/labeled_numina_difficulty_859K", trust_remote_code=True, split="train") + dataset = load_dataset( + "NovaSky-AI/labeled_numina_difficulty_859K", + trust_remote_code=True, + split="train", + ) for example in dataset: # print(example) diff_dict[example["problem"]] = example["gpt_difficulty_parsed"] @@ -361,24 +455,44 @@ def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: prompt_text = self.generate_prompt(problem["problem"]) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset("AI-MO/NuminaMath-CoT") train_data = dataset[split].to_pandas() - train_data = train_data.query('source == @source').iloc[start:end] if end > 0 else train_data.query('source == @source').iloc[start:] + train_data = ( + train_data.query("source == @source").iloc[start:end] + if end > 0 + else train_data.query("source == @source").iloc[start:] + ) train_data = train_data[train_data["solution"].str.contains("boxed", na=False)] if filter_difficulty: diff_dict = self.get_difficulty_dict(source, start, end) - train_data = train_data[train_data["problem"].map(diff_dict).apply(lambda x: x >= args.math_difficulty_lower_bound and x <= args.math_difficulty_upper_bound)] + train_data = train_data[ + train_data["problem"] + .map(diff_dict) + .apply( + lambda x: x >= args.math_difficulty_lower_bound + and x <= args.math_difficulty_upper_bound + ) + ] return train_data def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["problem"]) not in results] + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["problem"]) not in results + ] + class APPSTaskHandler(TaskHandler): @staticmethod @@ -390,38 +504,43 @@ def generate_prompt(test_case, prompt, starter_code=None): _input = "" data = test_case if not data.get("fn_name"): - _input += "Generate an executable Python function generated 
from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition."# "\nUse Standard Input format"#\n" + _input += "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition." # "\nUse Standard Input format"#\n" else: - _input += "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution." #"\nUse Call-Based format"#\n" + _input += "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution." # "\nUse Call-Based format"#\n" data = prompt _input += data if starter_code != None: data = starter_code - data = "\n" + data #+ "\n" + data = "\n" + data # + "\n" _input += data else: - #_input += "\n\n" + # _input += "\n\n" pass - + return _input - + def check_correctness(self, problem, generation): TIMEOUT = 10 + def _temp_run(problem, generation, debug, result): try: - result.append(apps_run_test(problem=problem, test=generation, debug=debug)) - except Exception as e: + result.append( + apps_run_test(problem=problem, test=generation, debug=debug) + ) + except Exception: pass manager = Manager() result = manager.list() - p = multiprocessing.Process(target=_temp_run, args=(problem, generation, False, result)) + p = multiprocessing.Process( + target=_temp_run, args=(problem, generation, False, result) + ) p.start() p.join(timeout=TIMEOUT + 1) if p.is_alive(): p.kill() return bool(result and np.all(result[0])) - + def update_results(self, problem, response): if not isinstance(response, str): response = response.outputs[0].text.strip() @@ -435,15 +554,15 @@ def update_results(self, problem, response): if len(code_filter_result) == 0: response_entry["correctness"] = False response_entry["reason"] = "Does not contain code component." - else: + else: last_code = code_filter_result[-1] problem_to_check = copy.deepcopy(problem) problem_to_check["input_output"] = json.loads(problem["input_output"]) try: - problem_to_check["solutions"] = json.loads(problem["solutions"]) + problem_to_check["solutions"] = json.loads(problem["solutions"]) except: problem_to_check["solutions"] = "" - print(f"Empty solution from the dataset") + print("Empty solution from the dataset") curr_res = self.check_correctness(problem_to_check, generation=last_code) if curr_res: response_entry["correctness"] = True @@ -451,7 +570,7 @@ def update_results(self, problem, response): else: response_entry["correctness"] = False response_entry["reason"] = "Code is incorrect." 
- + return response_entry def make_conversations(self, data, system_prompt, model=None): @@ -459,22 +578,37 @@ def make_conversations(self, data, system_prompt, model=None): for problem in data: test_case = json.loads(problem["input_output"]) starter_code = problem["starter_code"] - prompt_text = self.generate_prompt(test_case, problem["question"], starter_code) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + prompt_text = self.generate_prompt( + test_case, problem["question"], starter_code + ) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset("codeparrot/apps", trust_remote_code=True) train_data = dataset[split].to_pandas() if not filter_difficulty: return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - return train_data.query('difficulty == @source').iloc[start:end] if end > 0 else train_data.query('difficulty == @source').iloc[start:] + return ( + train_data.query("difficulty == @source").iloc[start:end] + if end > 0 + else train_data.query("difficulty == @source").iloc[start:] + ) def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["question"]) not in results] + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["question"]) not in results + ] + class TACOTaskHandler(TaskHandler): @staticmethod @@ -494,11 +628,12 @@ def generate_prompt(prompt, starter_code=None, fn_name=None): call_format = "\nUse Call-Based format" _input += call_format _input += "\nANSWER:\n" - + return _input - + def check_correctness(self, problem, generation): TIME_OUT = 300 + def _temp_run(problem, generation, debug, result): try: result.append(taco_run_test(problem, test=generation, debug=debug)) @@ -507,13 +642,15 @@ def _temp_run(problem, generation, debug, result): manager = Manager() result = manager.list() - p = multiprocessing.Process(target=_temp_run, args=(problem, generation, False, result)) + p = multiprocessing.Process( + target=_temp_run, args=(problem, generation, False, result) + ) p.start() p.join(timeout=TIME_OUT + 1) if p.is_alive(): p.kill() return bool(result and np.all(result[0])) - + def update_results(self, problem, response): if not isinstance(response, str): response = response.outputs[0].text.strip() @@ -527,7 +664,7 @@ def update_results(self, problem, response): if len(code_filter_result) == 0: response_entry["correctness"] = False response_entry["reason"] = "Does not contain code component." - else: + else: last_code = code_filter_result[-1] curr_res = self.check_correctness(problem, generation=last_code) if curr_res: @@ -536,36 +673,55 @@ def update_results(self, problem, response): else: response_entry["correctness"] = False response_entry["reason"] = "Code is incorrect." 
- + return response_entry def make_conversations(self, data, system_prompt, model=None): conversations = [] for idx, problem in enumerate(data): - starter_code = None if len(problem["starter_code"]) == 0 else problem["starter_code"] + starter_code = ( + None if len(problem["starter_code"]) == 0 else problem["starter_code"] + ) try: input_outpout = json.loads(problem["input_output"]) fn_name = ( - None if not input_outpout.get("fn_name") else input_outpout["fn_name"] + None + if not input_outpout.get("fn_name") + else input_outpout["fn_name"] ) except ValueError: fn_name = None - prompt_text = self.generate_prompt(problem["question"], starter_code, fn_name) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + prompt_text = self.generate_prompt( + problem["question"], starter_code, fn_name + ) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset("BAAI/TACO", "ALL", trust_remote_code=True) train_data = dataset[split].to_pandas() if not filter_difficulty: return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - return train_data.query('difficulty == @source').iloc[start:end] if end > 0 else train_data.query('difficulty == @source').iloc[start:] + return ( + train_data.query("difficulty == @source").iloc[start:end] + if end > 0 + else train_data.query("difficulty == @source").iloc[start:] + ) def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["question"]) not in results] + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["question"]) not in results + ] + class LiveCodeBenchTaskHandler(TaskHandler): @staticmethod @@ -573,14 +729,20 @@ def generate_prompt(problem): # print(problem) prompt = problem["prompt"] if problem["is_stdin"]: - return "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition." + prompt + return ( + "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition." + + prompt + ) else: - return "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution." + prompt - + return ( + "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution." + + prompt + ) + @staticmethod def get_question_key(): return "task_id" - + def check_correctness( self, problem: Dict, @@ -596,16 +758,18 @@ def check_correctness( :param completion_id: an optional completion ID so we can match the results later even if execution finishes asynchronously. 
""" - result_list = unsafe_lcb_runTests(problem, completion, timeout, runtime_debug, is_extracted) + result_list = unsafe_lcb_runTests( + problem, completion, timeout, runtime_debug, is_extracted + ) details = [r[0] for r in result_list] all_passed = all(details) - + result = "" if result_list and all_passed: result = "passed" return result == "passed" - + def update_results(self, problem, response): if not isinstance(response, str): response = response.outputs[0].text.strip() @@ -620,52 +784,75 @@ def update_results(self, problem, response): if len(code_filter_result) == 0: response_entry["correctness"] = False response_entry["reason"] = "Does not contain code component." - else: + else: last_code = code_filter_result[-1] problem_to_check = copy.deepcopy(problem) - curr_res = self.check_correctness(problem=problem_to_check, completion=post_process_code(last_code), timeout=6, is_extracted=not problem_to_check["is_stdin"]) + curr_res = self.check_correctness( + problem=problem_to_check, + completion=post_process_code(last_code), + timeout=6, + is_extracted=not problem_to_check["is_stdin"], + ) if curr_res: response_entry["correctness"] = True response_entry["reason"] = "" else: response_entry["correctness"] = False response_entry["reason"] = "Code is incorrect." - + return response_entry def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: prompt_text = self.generate_prompt(problem) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - def load_and_filter_dataset(self, start, end, split="test", source=None, filter_difficulty=False, args=None): - dataset = load_dataset("livecodebench/code_generation_lite", version_tag="release_v2", split=split, trust_remote_code=True) + def load_and_filter_dataset( + self, start, end, split="test", source=None, filter_difficulty=False, args=None + ): + dataset = load_dataset( + "livecodebench/code_generation_lite", + version_tag="release_v2", + split=split, + trust_remote_code=True, + ) if filter_difficulty: - dataset = dataset.filter(lambda example: example['difficulty'] == source) + dataset = dataset.filter(lambda example: example["difficulty"] == source) dataset = dataset.map( lambda example: { - "private_test_cases": translate_private_test_cases(example["private_test_cases"]) + "private_test_cases": translate_private_test_cases( + example["private_test_cases"] + ) } ) # Apply the mapping function - dataset = dataset.map(map_to_example, remove_columns=dataset.column_names).to_pandas() + dataset = dataset.map( + map_to_example, remove_columns=dataset.column_names + ).to_pandas() return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["task_id"]) not in results] + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["task_id"]) not in results + ] + class GSM8KTaskHandler(TaskHandler): def __init__(self) -> None: super().__init__() self.dataset = "openai/gsm8k" self.ans_re = re.compile(r"((-?[$0-9.,]{2,})|(-?[0-9]+))") - self.gt_re = re.compile(r"#### (\-?[0-9\.\,]+)") + self.gt_re = re.compile(r"#### (\-?[0-9\.\,]+)") self.invalid_ans = "[invalid]" @staticmethod @@ -674,16 +861,16 @@ def get_question_key(): @staticmethod def 
generate_prompt(problem): - question = problem["question"] - full_prompt = f"Given the following problem, reason and give a final answer to the problem.\nProblem: {question}\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem." + question = problem["question"] + full_prompt = f'Given the following problem, reason and give a final answer to the problem.\nProblem: {question}\nYour response should end with "The final answer is [answer]" where [answer] is the response to the problem.' return full_prompt - - def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: + + def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: gt_answer = self.extract_gt_answer(problem["answer"]) model_answer = extract_answer(generation) model_answer = self.sanitize_answer(model_answer) return model_answer == gt_answer - + def update_results(self, problem, response): if not isinstance(response, str): response = response.outputs[0].text.strip() @@ -693,34 +880,42 @@ def update_results(self, problem, response): "correctness": None, "reason": None, } - curr_res= self.check_correctness(problem, generation=response) + curr_res = self.check_correctness(problem, generation=response) if curr_res: response_entry["correctness"] = True response_entry["reason"] = "" else: response_entry["correctness"] = False response_entry["reason"] = "Solution is incorrect." - + return response_entry def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: prompt_text = self.generate_prompt(problem) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset, "main") train_data = dataset[split].to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["question"]) not in results] - + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["question"]) not in results + ] + def extract_gt_answer(self, completion): match = self.gt_re.search(completion) if match: @@ -732,14 +927,13 @@ def extract_gt_answer(self, completion): def sanitize_answer(self, answer): patterns_to_remove = [ - ',', # Remove commas - r'\$', # Remove dollar signs - r'\.$' # Remove trailing period - r"\*", # Remove asterisks + ",", # Remove commas + r"\$", # Remove dollar signs + r"\.$" r"\*", # Remove trailing period # Remove asterisks ] for pattern in patterns_to_remove: - answer = re.sub(pattern, '', answer) - + answer = re.sub(pattern, "", answer) + matches = self.ans_re.findall(answer) if matches: # get the last match (i.e final response) and the first / outer capturing group @@ -748,12 +942,13 @@ def sanitize_answer(self, answer): else: return self.invalid_ans -class ARCChallengeTaskHandler(TaskHandler): + +class ARCChallengeTaskHandler(TaskHandler): def __init__(self) -> None: super().__init__() self.dataset = "allenai/ai2_arc" self.ans_re = re.compile(r"[Tt]he best answer is 
([A-D])[\.\,]*", re.IGNORECASE) - self.letter_re = re.compile(r"([A-D])[\.\,]*") + self.letter_re = re.compile(r"([A-D])[\.\,]*") self.canonical_options = ["A", "B", "C", "D"] self.invalid_ans = "[invalid]" @@ -763,19 +958,27 @@ def get_question_key(): @staticmethod def generate_prompt(problem): - question = problem["question"] + question = problem["question"] choices = problem["choices"] - choices_text = '\n'.join([f"{label}.{choice}" for label, choice in zip(["A", "B", "C", "D"], choices["text"])]) - full_prompt = "Given the following question and four candidate answers (A, B, C and D), choose the best answer. Your response should end with \"The best answer is [the_answer_letter]\" where [the_answer_letter] is one of the four letter choice (A, B, C, or D).\n" + f"{question}\n{choices_text}" + choices_text = "\n".join( + [ + f"{label}.{choice}" + for label, choice in zip(["A", "B", "C", "D"], choices["text"]) + ] + ) + full_prompt = ( + 'Given the following question and four candidate answers (A, B, C and D), choose the best answer. Your response should end with "The best answer is [the_answer_letter]" where [the_answer_letter] is one of the four letter choice (A, B, C, or D).\n' + + f"{question}\n{choices_text}" + ) return full_prompt - - def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: + + def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: gt_answer = problem["answerKey"] if gt_answer not in self.canonical_options: gt_answer = self.canonical_options[int(problem["answerKey"]) - 1] model_answer = self.get_answer(generation) return model_answer == gt_answer - + def update_results(self, problem, response): if not isinstance(response, str): response = response.outputs[0].text.strip() @@ -792,54 +995,62 @@ def update_results(self, problem, response): else: response_entry["correctness"] = False response_entry["reason"] = "Solution is incorrect." - + return response_entry def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: prompt_text = self.generate_prompt(problem) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset, "ARC-Challenge") train_data = dataset[split].to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["question"]) not in results] + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["question"]) not in results + ] def get_answer(self, completion): # First, we try to extract similar to MATH answers answer = extract_answer(completion) match = None - if answer: - # match for the letter answer needed. + if answer: + # match for the letter answer needed. 
match = self.letter_re.search(answer) - if match: + if match: return match.group(1).strip() - - if not answer or not match: - # try basic-regex based search + + if not answer or not match: + # try basic-regex based search patterns_to_remove = [ - ',', # Remove commas - r'\$', # Remove dollar signs - r'\.$' # Remove trailing period - r"\\", # Remove stray backslashes - r"\*", # Remove asterisks + ",", # Remove commas + r"\$", # Remove dollar signs + r"\.$" r"\\", # Remove trailing period # Remove stray backslashes + r"\*", # Remove asterisks ] answer = completion for pattern in patterns_to_remove: - answer = re.sub(pattern, '', answer) + answer = re.sub(pattern, "", answer) matches = self.ans_re.findall(answer) - if not matches: + if not matches: return self.invalid_ans return matches[-1].strip() + class AMC23TaskHandler(MathTaskHandler): def __init__(self): self.dataset = "AI-MO/aimo-validation-amc" @@ -847,11 +1058,13 @@ def __init__(self): @staticmethod def get_question_key(): return "problem" - - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset) train_data = dataset[split].to_pandas() - filtered_data = train_data[train_data['url'].str.contains("2023", na=False)] + filtered_data = train_data[train_data["url"].str.contains("2023", na=False)] return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] From 0cf35489b234e6e7ad62f0dace2479e1d7d083cf Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 00:46:42 +0000 Subject: [PATCH 03/47] add a bunch of stuff Signed-off-by: SumanthRH --- skythought/tools/.githooks/pre-commit | 1 + 1 file changed, 1 insertion(+) diff --git a/skythought/tools/.githooks/pre-commit b/skythought/tools/.githooks/pre-commit index 094125f..b28029b 100644 --- a/skythought/tools/.githooks/pre-commit +++ b/skythought/tools/.githooks/pre-commit @@ -1,3 +1,4 @@ +set -e # Only run pre-commit if changes are in tools/ if git diff --cached --name-only | grep "^tools/"; then cd skythought/tools/ From f48b5dc8bff837682fe4e405f3637321501c4f16 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 00:54:18 +0000 Subject: [PATCH 04/47] check Signed-off-by: SumanthRH --- skythought/tools/.githooks/pre-commit | 0 skythought/tools/format.sh | 4 ++++ 2 files changed, 4 insertions(+) mode change 100644 => 100755 skythought/tools/.githooks/pre-commit diff --git a/skythought/tools/.githooks/pre-commit b/skythought/tools/.githooks/pre-commit old mode 100644 new mode 100755 diff --git a/skythought/tools/format.sh b/skythought/tools/format.sh index 16cbd4e..34ef0df 100644 --- a/skythought/tools/format.sh +++ b/skythought/tools/format.sh @@ -9,6 +9,10 @@ else pip install -q pre-commit fi +# Hook file should be executable +HOOK_SCRIPT=$TOOLS_DIR/.githooks/pre-commit +chmod +x $HOOK_SCRIPT + git config --local core.hooksPath "$TOOLS_DIR/.githooks" # pre-commit run --all-files always runs from the root directory. we run this only on tools/ for now. 
pre-commit run --files $TOOLS_DIR/* \ No newline at end of file From 3e624af433c46a46537883503d1a8bf9a58e3ea0 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 00:54:58 +0000 Subject: [PATCH 05/47] check Signed-off-by: SumanthRH --- skythought/tools/tasks/aime/aime_handler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/skythought/tools/tasks/aime/aime_handler.py b/skythought/tools/tasks/aime/aime_handler.py index b8c54dd..e0b745a 100644 --- a/skythought/tools/tasks/aime/aime_handler.py +++ b/skythought/tools/tasks/aime/aime_handler.py @@ -38,3 +38,7 @@ def load_and_filter_dataset( train_data = dataset[split].to_pandas() filtered_data = train_data[train_data["url"].str.contains("2024", na=False)] return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] + + def dummy(): + raise NotImplementedError() + From 859f0230cd5c99d431d5006685ae35efac0d5c2e Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 00:59:30 +0000 Subject: [PATCH 06/47] check Signed-off-by: SumanthRH --- skythought/tools/.githooks/pre-commit | 8 +++++--- skythought/tools/format.sh | 2 +- skythought/tools/tasks/aime/aime_handler.py | 1 - 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/skythought/tools/.githooks/pre-commit b/skythought/tools/.githooks/pre-commit index b28029b..2cc1e30 100755 --- a/skythought/tools/.githooks/pre-commit +++ b/skythought/tools/.githooks/pre-commit @@ -1,6 +1,8 @@ set -e + +# Get tools directory path relative to git root +TOOLS_DIR=$(git rev-parse --show-toplevel)/skythought/tools # Only run pre-commit if changes are in tools/ -if git diff --cached --name-only | grep "^tools/"; then - cd skythought/tools/ - pre-commit run --files $(git diff --cached --name-only | grep "^tools/") +if git diff --cached --name-only | grep "^skythought/tools/"; then + pre-commit run --files $(git diff --cached --name-only | grep "^skythought/tools/") --config $TOOLS_DIR/.pre-commit-config.yaml fi \ No newline at end of file diff --git a/skythought/tools/format.sh b/skythought/tools/format.sh index 34ef0df..75d52f5 100644 --- a/skythought/tools/format.sh +++ b/skythought/tools/format.sh @@ -15,4 +15,4 @@ chmod +x $HOOK_SCRIPT git config --local core.hooksPath "$TOOLS_DIR/.githooks" # pre-commit run --all-files always runs from the root directory. we run this only on tools/ for now. 
-pre-commit run --files $TOOLS_DIR/* \ No newline at end of file +pre-commit run --files $TOOLS_DIR/* --config $TOOLS_DIR/.pre-commit-config.yaml \ No newline at end of file diff --git a/skythought/tools/tasks/aime/aime_handler.py b/skythought/tools/tasks/aime/aime_handler.py index e0b745a..c0504d8 100644 --- a/skythought/tools/tasks/aime/aime_handler.py +++ b/skythought/tools/tasks/aime/aime_handler.py @@ -41,4 +41,3 @@ def load_and_filter_dataset( def dummy(): raise NotImplementedError() - From d877bee21df257daf7c51eac0005c4c7c0ecb69f Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 01:24:05 +0000 Subject: [PATCH 07/47] more refactoring Signed-off-by: SumanthRH --- skythought/tools/inference_and_check.py | 2 +- skythought/tools/tasks/__init__.py | 46 + skythought/tools/tasks/arc/arc_handler.py | 115 +++ skythought/tools/tasks/common.py | 183 +--- skythought/tools/util/task_handlers.py | 1084 --------------------- 5 files changed, 169 insertions(+), 1261 deletions(-) create mode 100644 skythought/tools/tasks/arc/arc_handler.py delete mode 100644 skythought/tools/util/task_handlers.py diff --git a/skythought/tools/inference_and_check.py b/skythought/tools/inference_and_check.py index 4deb736..816477e 100644 --- a/skythought/tools/inference_and_check.py +++ b/skythought/tools/inference_and_check.py @@ -10,8 +10,8 @@ from tqdm import tqdm from vllm import LLM, SamplingParams +from tasks import TASK_HANDLERS, NUMINATaskHandler, TaskHandler from util.model_utils import MODEL_TO_NAME, SYSTEM_PROMPT -from util.task_handlers import TASK_HANDLERS, NUMINATaskHandler, TaskHandler class NumpyEncoder(json.JSONEncoder): diff --git a/skythought/tools/tasks/__init__.py b/skythought/tools/tasks/__init__.py index e69de29..8dc2388 100644 --- a/skythought/tools/tasks/__init__.py +++ b/skythought/tools/tasks/__init__.py @@ -0,0 +1,46 @@ +from .aime.aime_handler import AIMETaskHandler +from .amc23.amc23_handler import AMC23TaskHandler +from .apps.apps_handler import APPSTaskHandler +from .arc.arc_handler import ARCChallengeTaskHandler +from .common import TaskHandler +from .gpqa_diamond.gpqa_diamond_handler import GPQADiamondTaskHandler +from .gsm8k.gsm8k_handler import GSM8KTaskHandler +from .livecodebench.livecodebench_handler import LiveCodeBenchTaskHandler +from .math.math_handler import MATH500TaskHandler, MathTaskHandler +from .mmlu.mmlu_handler import MMLUProTaskHandler, MMLUTaskHandler +from .numina.numina_handler import NUMINATaskHandler +from .taco.taco_handler import TACOTaskHandler + +TASK_HANDLERS = { + "NUMINA": NUMINATaskHandler, + "APPS": APPSTaskHandler, + "TACO": TACOTaskHandler, + "MATH500": MATH500TaskHandler, + "AIME": AIMETaskHandler, + "GPQADiamond": GPQADiamondTaskHandler, + "MMLU": MMLUTaskHandler, + "MMLUPro": MMLUProTaskHandler, + "LiveCodeBench": LiveCodeBenchTaskHandler, + "GSM8K": GSM8KTaskHandler, + "ARC-C": ARCChallengeTaskHandler, + "AMC23": AMC23TaskHandler, +} + + +__all__ = [ + AIMETaskHandler, + APPSTaskHandler, + TACOTaskHandler, + MATH500TaskHandler, + AMC23TaskHandler, + NUMINATaskHandler, + GPQADiamondTaskHandler, + MMLUTaskHandler, + MMLUProTaskHandler, + LiveCodeBenchTaskHandler, + GSM8KTaskHandler, + ARCChallengeTaskHandler, + TaskHandler, + MathTaskHandler, + TASK_HANDLERS, +] diff --git a/skythought/tools/tasks/arc/arc_handler.py b/skythought/tools/tasks/arc/arc_handler.py new file mode 100644 index 0000000..3c6c3aa --- /dev/null +++ b/skythought/tools/tasks/arc/arc_handler.py @@ -0,0 +1,115 @@ +import re +from typing import Any, Dict + +from 
datasets import load_dataset + +from tasks.common import TaskHandler +from util.math_parsing_util import extract_answer + + +class ARCChallengeTaskHandler(TaskHandler): + def __init__(self) -> None: + super().__init__() + self.dataset = "allenai/ai2_arc" + self.ans_re = re.compile(r"[Tt]he best answer is ([A-D])[\.\,]*", re.IGNORECASE) + self.letter_re = re.compile(r"([A-D])[\.\,]*") + self.canonical_options = ["A", "B", "C", "D"] + self.invalid_ans = "[invalid]" + + @staticmethod + def get_question_key(): + return "question" + + @staticmethod + def generate_prompt(problem): + question = problem["question"] + choices = problem["choices"] + choices_text = "\n".join( + [ + f"{label}.{choice}" + for label, choice in zip(["A", "B", "C", "D"], choices["text"]) + ] + ) + full_prompt = ( + 'Given the following question and four candidate answers (A, B, C and D), choose the best answer. Your response should end with "The best answer is [the_answer_letter]" where [the_answer_letter] is one of the four letter choice (A, B, C, or D).\n' + + f"{question}\n{choices_text}" + ) + return full_prompt + + def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: + gt_answer = problem["answerKey"] + if gt_answer not in self.canonical_options: + gt_answer = self.canonical_options[int(problem["answerKey"]) - 1] + model_answer = self.get_answer(generation) + return model_answer == gt_answer + + def update_results(self, problem, response): + if not isinstance(response, str): + response = response.outputs[0].text.strip() + # Initialize the response structure + response_entry = { + "content": response, + "correctness": None, + "reason": None, + } + curr_res = self.check_correctness(problem, generation=response) + if curr_res: + response_entry["correctness"] = True + response_entry["reason"] = "" + else: + response_entry["correctness"] = False + response_entry["reason"] = "Solution is incorrect." + + return response_entry + + def make_conversations(self, data, system_prompt, model=None): + conversations = [] + for problem in data: + prompt_text = self.generate_prompt(problem) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) + return conversations + + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): + dataset = load_dataset(self.dataset, "ARC-Challenge") + train_data = dataset[split].to_pandas() + return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] + + def process_remaining_data(self, train_data, results): + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["question"]) not in results + ] + + def get_answer(self, completion): + # First, we try to extract similar to MATH answers + answer = extract_answer(completion) + match = None + if answer: + # match for the letter answer needed. 
+ match = self.letter_re.search(answer) + if match: + return match.group(1).strip() + + if not answer or not match: + # try basic-regex based search + patterns_to_remove = [ + ",", # Remove commas + r"\$", # Remove dollar signs + r"\.$" r"\\", # Remove trailing period # Remove stray backslashes + r"\*", # Remove asterisks + ] + answer = completion + for pattern in patterns_to_remove: + answer = re.sub(pattern, "", answer) + matches = self.ans_re.findall(answer) + if not matches: + return self.invalid_ans + return matches[-1].strip() diff --git a/skythought/tools/tasks/common.py b/skythought/tools/tasks/common.py index 0b10d97..cd6e88a 100644 --- a/skythought/tools/tasks/common.py +++ b/skythought/tools/tasks/common.py @@ -1,19 +1,7 @@ -import copy import json -import multiprocessing import os -import random import re -import numpy as np -from datasets import load_dataset -from typing import Dict, Any -from multiprocessing import Manager -from tasks.apps.apps_util import run_test as apps_run_test -from tasks.taco.taco_util import run_test as taco_run_test -from ..util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer -from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases -from ..util.common import TimeoutException, timeout -from util.model_utils import * + def has_code(response): pattern = r"```(?:[a-zA-Z]*)\n(.*?)```" @@ -22,6 +10,7 @@ def has_code(response): # print(matches) return matches + class TaskHandler: @staticmethod def get_question_key(): @@ -29,7 +18,7 @@ def get_question_key(): def check_correctness(self, problem, generation): raise NotImplementedError("Subclasses should implement this method.") - + def update_results(self, problem, response): raise NotImplementedError("Subclasses should implement this method.") @@ -39,172 +28,14 @@ def make_conversations(self, data, system_prompt, model=None): def load_existing_results(self, result_file): if not os.path.exists(result_file): return {} - with open(result_file, 'r', encoding='utf-8') as f: + with open(result_file, "r", encoding="utf-8") as f: records = json.load(f) return records - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): raise NotImplementedError("Subclasses should implement this method.") def process_remaining_data(self, train_data, results): raise NotImplementedError("Subclasses should implement this method.") - - -class MathTaskHandler(TaskHandler): - @staticmethod - def generate_prompt(prompt): - return "Return your final response within \\boxed{{}}. 
" + prompt - - def check_correctness(self, problem, generation): - answer = strip_answer_string(problem["answer"]) - pred = extract_answer(generation) - # print(problem) - pred = strip_answer_string(pred) - return math_equal(pred, answer) - - def update_results(self, problem, response): - if not isinstance(response, str): - response = response.outputs[0].text.strip() - # Initialize the response structure - response_entry = { - "content": response, - "correctness": None, - "reason": None, - } - curr_res = self.check_correctness(problem, generation=response) - if curr_res: - response_entry["correctness"] = True - response_entry["reason"] = "" - else: - response_entry["correctness"] = False - response_entry["reason"] = "Solution is incorrect." - - return response_entry - - def make_conversations(self, data, system_prompt, model=None): - conversations = [] - for problem in data: - prompt_text = self.generate_prompt(problem["problem"]) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) - return conversations - - def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["problem"]) not in results] - - def load_and_filter_dataset(self, start, end, split="test", source=None, filter_difficulty=False, args=None): - dataset = load_dataset(self.dataset) - train_data = dataset[split].to_pandas() - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - - - -class ARCChallengeTaskHandler(TaskHandler): - def __init__(self) -> None: - super().__init__() - self.dataset = "allenai/ai2_arc" - self.ans_re = re.compile(r"[Tt]he best answer is ([A-D])[\.\,]*", re.IGNORECASE) - self.letter_re = re.compile(r"([A-D])[\.\,]*") - self.canonical_options = ["A", "B", "C", "D"] - self.invalid_ans = "[invalid]" - - @staticmethod - def get_question_key(): - return "question" - - @staticmethod - def generate_prompt(problem): - question = problem["question"] - choices = problem["choices"] - choices_text = '\n'.join([f"{label}.{choice}" for label, choice in zip(["A", "B", "C", "D"], choices["text"])]) - full_prompt = "Given the following question and four candidate answers (A, B, C and D), choose the best answer. Your response should end with \"The best answer is [the_answer_letter]\" where [the_answer_letter] is one of the four letter choice (A, B, C, or D).\n" + f"{question}\n{choices_text}" - return full_prompt - - def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: - gt_answer = problem["answerKey"] - if gt_answer not in self.canonical_options: - gt_answer = self.canonical_options[int(problem["answerKey"]) - 1] - model_answer = self.get_answer(generation) - return model_answer == gt_answer - - def update_results(self, problem, response): - if not isinstance(response, str): - response = response.outputs[0].text.strip() - # Initialize the response structure - response_entry = { - "content": response, - "correctness": None, - "reason": None, - } - curr_res = self.check_correctness(problem, generation=response) - if curr_res: - response_entry["correctness"] = True - response_entry["reason"] = "" - else: - response_entry["correctness"] = False - response_entry["reason"] = "Solution is incorrect." 
- - return response_entry - - def make_conversations(self, data, system_prompt, model=None): - conversations = [] - for problem in data: - prompt_text = self.generate_prompt(problem) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) - return conversations - - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): - dataset = load_dataset(self.dataset, "ARC-Challenge") - train_data = dataset[split].to_pandas() - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - - def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["question"]) not in results] - - def get_answer(self, completion): - # First, we try to extract similar to MATH answers - answer = extract_answer(completion) - match = None - if answer: - # match for the letter answer needed. - match = self.letter_re.search(answer) - if match: - return match.group(1).strip() - - if not answer or not match: - # try basic-regex based search - patterns_to_remove = [ - ',', # Remove commas - r'\$', # Remove dollar signs - r'\.$' # Remove trailing period - r"\\", # Remove stray backslashes - r"\*", # Remove asterisks - ] - answer = completion - for pattern in patterns_to_remove: - answer = re.sub(pattern, '', answer) - matches = self.ans_re.findall(answer) - if not matches: - return self.invalid_ans - return matches[-1].strip() - - -TASK_HANDLERS = { - "NUMINA": NUMINATaskHandler, - "APPS": APPSTaskHandler, - "TACO": TACOTaskHandler, - "MATH500": MATH500TaskHandler, - "AIME": AIMETaskHandler, - "GPQADiamond": GPQADiamondTaskHandler, - "MMLU": MMLUTaskHandler, - "MMLUPro": MMLUProTaskHandler, - "LiveCodeBench": LiveCodeBenchTaskHandler, - "GSM8K": GSM8KTaskHandler, - "ARC-C": ARCChallengeTaskHandler, - "AMC23": AMC23TaskHandler, -} diff --git a/skythought/tools/util/task_handlers.py b/skythought/tools/util/task_handlers.py deleted file mode 100644 index 3cb6895..0000000 --- a/skythought/tools/util/task_handlers.py +++ /dev/null @@ -1,1084 +0,0 @@ -import copy -import json -import multiprocessing -import os -import random -import re -from multiprocessing import Manager -from typing import Any, Dict - -import numpy as np -from datasets import load_dataset - -from tasks.apps.apps_util import run_test as apps_run_test -from tasks.livecodebench.livecodebench_util import ( - map_to_example, - post_process_code, - translate_private_test_cases, - unsafe_lcb_runTests, -) -from tasks.taco.taco_util import run_test as taco_run_test -from util.model_utils import * - -from .common import TimeoutException, timeout -from .math_parsing_util import ( - extract_answer, - get_multiple_choice_answer, - math_equal, - mmlu_pro_extract_answer, - strip_answer_string, -) - - -def has_code(response): - pattern = r"```(?:[a-zA-Z]*)\n(.*?)```" - # Use re.DOTALL to match multiline content inside backticks - matches = re.findall(pattern, response, re.DOTALL) - # print(matches) - return matches - - -class TaskHandler: - @staticmethod - def get_question_key(): - raise NotImplementedError("Subclasses should implement this method.") - - def check_correctness(self, problem, generation): - raise NotImplementedError("Subclasses should implement this method.") - - def update_results(self, problem, response): - raise NotImplementedError("Subclasses should implement this method.") - - def make_conversations(self, data, system_prompt, model=None): - raise 
NotImplementedError("Subclasses should implement this method.") - - def load_existing_results(self, result_file): - if not os.path.exists(result_file): - return {} - with open(result_file, "r", encoding="utf-8") as f: - records = json.load(f) - return records - - def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None - ): - raise NotImplementedError("Subclasses should implement this method.") - - def process_remaining_data(self, train_data, results): - raise NotImplementedError("Subclasses should implement this method.") - - -class MathTaskHandler(TaskHandler): - @staticmethod - def generate_prompt(prompt): - return "Return your final response within \\boxed{{}}. " + prompt - - def check_correctness(self, problem, generation): - answer = strip_answer_string(problem["answer"]) - pred = extract_answer(generation) - # print(problem) - pred = strip_answer_string(pred) - return math_equal(pred, answer) - - def update_results(self, problem, response): - if not isinstance(response, str): - response = response.outputs[0].text.strip() - # Initialize the response structure - response_entry = { - "content": response, - "correctness": None, - "reason": None, - } - curr_res = self.check_correctness(problem, generation=response) - if curr_res: - response_entry["correctness"] = True - response_entry["reason"] = "" - else: - response_entry["correctness"] = False - response_entry["reason"] = "Solution is incorrect." - - return response_entry - - def make_conversations(self, data, system_prompt, model=None): - conversations = [] - for problem in data: - prompt_text = self.generate_prompt(problem["problem"]) - conversations.append( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text}, - ] - ) - return conversations - - def process_remaining_data(self, train_data, results): - return [ - row.to_dict() - for _, row in train_data.iterrows() - if str(row["problem"]) not in results - ] - - def load_and_filter_dataset( - self, start, end, split="test", source=None, filter_difficulty=False, args=None - ): - dataset = load_dataset(self.dataset) - train_data = dataset[split].to_pandas() - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - - -class MATH500TaskHandler(MathTaskHandler): - def __init__(self): - self.dataset = "qq8933/MATH500" - - @staticmethod - def get_question_key(): - return "problem" - - -class AIMETaskHandler(MathTaskHandler): - def __init__(self): - self.dataset = "AI-MO/aimo-validation-aime" - - @staticmethod - def generate_prompt(prompt, model): - if MODEL_TO_NAME[model] == "Sky-T1-32B-Preview": - return prompt + "\nReturn your final response within \\boxed{{}}" - else: - return "Return your final response within \\boxed{{}}. 
" + prompt - - @staticmethod - def get_question_key(): - return "problem" - - def make_conversations(self, data, system_prompt, model=None): - conversations = [] - for problem in data: - prompt_text = self.generate_prompt(problem["problem"], model) - conversations.append( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text}, - ] - ) - return conversations - - def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None - ): - dataset = load_dataset(self.dataset) - train_data = dataset[split].to_pandas() - filtered_data = train_data[train_data["url"].str.contains("2024", na=False)] - return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] - - -class GPQADiamondTaskHandler(TaskHandler): - def __init__(self): - self.dataset = "Idavidrein/gpqa" - - @staticmethod - def generate_prompt(prompt): - return ( - "Return your final response within \\boxed{{}} and only include the letter choice (A, B, C, or D) as your final response. " - + prompt - ) - - @staticmethod - def get_question_key(): - return "Question" - - def update_results(self, problem, response): - if not isinstance(response, str): - response = response.outputs[0].text.strip() - # Initialize the response structure - response_entry = { - "content": response, - "correctness": None, - "reason": None, - } - curr_res = self.check_correctness(problem, generation=response) - if curr_res: - response_entry["correctness"] = True - response_entry["reason"] = "" - else: - response_entry["correctness"] = False - response_entry["reason"] = "Solution is incorrect." - - return response_entry - - def check_correctness(self, problem, generation): - pred = get_multiple_choice_answer(generation) - answer = problem["Answer"] - return answer == pred - - def get_multiple_choice_answers(self, data): - answers = [ - data["Correct Answer"], - data["Incorrect Answer 1"], - data["Incorrect Answer 2"], - data["Incorrect Answer 3"], - ] - random.shuffle(answers) - - # Map options to letters - options = ["A", "B", "C", "D"] - options_to_answers = { - letter: answer for letter, answer in zip(options, answers) - } - - # Format the options into the string - multiple_choice_string = ", ".join( - f"{letter}) {options_to_answers[letter]}" for letter in options - ) - - # Save the letter corresponding to the correct answer - correct_answer_letter = next( - letter - for letter, answer in options_to_answers.items() - if answer == data["Correct Answer"] - ) - - return multiple_choice_string, correct_answer_letter - - def make_conversations(self, data, system_prompt, model=None): - conversations = [] - for problem in data: - ( - multiple_choice_string, - correct_answer_letter, - ) = self.get_multiple_choice_answers(problem) - problem["Answer"] = correct_answer_letter - prompt_text = self.generate_prompt( - problem["Question"] + "\n" + multiple_choice_string - ) - conversations.append( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text}, - ] - ) - return conversations - - def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None - ): - dataset = load_dataset(self.dataset, "gpqa_diamond") - train_data = dataset[split].to_pandas() - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - - def process_remaining_data(self, train_data, results): - return [ - row.to_dict() - for _, row in train_data.iterrows() - if str(row["Question"]) not in results - ] - 
- -class MMLUTaskHandler(TaskHandler): - def __init__(self): - self.dataset = "cais/mmlu" - - @staticmethod - def generate_prompt(prompt): - return "Return your final response within \\boxed{{}}. " + prompt - - @staticmethod - def get_question_key(): - return "question" - - def check_correctness(self, problem, generation): - pred = get_multiple_choice_answer(generation) - abcd = "ABCD" - answer = abcd[problem["answer"]] - return answer == pred - - def update_results(self, problem, response): - if not isinstance(response, str): - response = response.outputs[0].text.strip() - # Initialize the response structure - response_entry = { - "content": response, - "correctness": None, - "reason": None, - } - curr_res = self.check_correctness(problem, generation=response) - if curr_res: - response_entry["correctness"] = True - response_entry["reason"] = "" - else: - response_entry["correctness"] = False - response_entry["reason"] = "Solution is incorrect." - return response_entry - - def get_multiple_choice_answers(self, problem): - options = problem["choices"] - for i, (label, option) in enumerate(zip("ABCD", options)): - options[i] = f"({label}) {str(option).strip()}" - options = " ".join(options) - return f"Answer Choices: {options}" - - def make_conversations(self, data, system_prompt, model=None): - conversations = [] - for problem in data: - multiple_choice_string = self.get_multiple_choice_answers(problem) - prompt_text = self.generate_prompt( - problem["question"] + "\n" + multiple_choice_string - ) - conversations.append( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text}, - ] - ) - return conversations - - def process_remaining_data(self, train_data, results): - return [ - row.to_dict() - for _, row in train_data.iterrows() - if str(row["question"]) not in results - ] - - def load_and_filter_dataset( - self, start, end, split="test", source=None, filter_difficulty=False, args=None - ): - dataset = load_dataset(self.dataset, "all") - train_data = dataset[split].to_pandas() - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - - -class MMLUProTaskHandler(MMLUTaskHandler): - def __init__(self): - super().__init__() - self.dataset = "TIGER-Lab/MMLU-Pro" - self.choices = [ - "A", - "B", - "C", - "D", - "E", - "F", - "G", - "H", - "I", - "J", - "K", - "L", - "M", - "N", - "O", - "P", - ] - - @staticmethod - def generate_prompt(prompt): - return "Return your final response within \\boxed{{}}. " + prompt - - @staticmethod - def get_question_key(): - return "question" - - def check_correctness(self, problem, generation): - pred = mmlu_pro_extract_answer(generation) - answer = self.choices[problem["answer_index"]] - return answer == pred - - def get_multiple_choice_answers(self, problem): - options = problem["options"] - for i, (label, option) in enumerate(zip(self.choices[: len(options)], options)): - options[i] = f"({label}) {str(option).strip()}" - options = " ".join(options) - return f"Answer Choices: {options}" - - def load_and_filter_dataset( - self, start, end, split="test", source=None, filter_difficulty=False, args=None - ): - dataset = load_dataset(self.dataset, "default") - train_data = dataset[split].to_pandas() - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - - -class NUMINATaskHandler(TaskHandler): - @staticmethod - def get_question_key(): - return "problem" - - @staticmethod - def generate_prompt(prompt): - return "Return your final response within \\boxed{{}}. 
" + prompt - - @timeout(5) # Add timeout of 5 seconds - def check_correctness(self, problem, generation): - solution = extract_answer(problem["solution"]) - solution = strip_answer_string(solution) - pred = extract_answer(generation) - pred = strip_answer_string(pred) - return math_equal(pred, solution) - - def update_results(self, problem, response): - if not isinstance(response, str): - response = response.outputs[0].text.strip() - # Initialize the response structure - response_entry = { - "content": response, - "correctness": None, - "reason": None, - } - - try: - curr_res = self.check_correctness(problem, generation=response) - if curr_res: - response_entry["correctness"] = True - response_entry["reason"] = "" - else: - response_entry["correctness"] = False - response_entry["reason"] = "Solution is incorrect." - except TimeoutException as e: - response_entry["correctness"] = False - response_entry["reason"] = str(e) - - return response_entry - - @staticmethod - def get_difficulty_dict(source, start, end): - diff_dict = {} - dataset = load_dataset( - "NovaSky-AI/labeled_numina_difficulty_859K", - trust_remote_code=True, - split="train", - ) - for example in dataset: - # print(example) - diff_dict[example["problem"]] = example["gpt_difficulty_parsed"] - return diff_dict - - def make_conversations(self, data, system_prompt, model=None): - conversations = [] - for problem in data: - prompt_text = self.generate_prompt(problem["problem"]) - conversations.append( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text}, - ] - ) - return conversations - - def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None - ): - dataset = load_dataset("AI-MO/NuminaMath-CoT") - train_data = dataset[split].to_pandas() - train_data = ( - train_data.query("source == @source").iloc[start:end] - if end > 0 - else train_data.query("source == @source").iloc[start:] - ) - train_data = train_data[train_data["solution"].str.contains("boxed", na=False)] - if filter_difficulty: - diff_dict = self.get_difficulty_dict(source, start, end) - train_data = train_data[ - train_data["problem"] - .map(diff_dict) - .apply( - lambda x: x >= args.math_difficulty_lower_bound - and x <= args.math_difficulty_upper_bound - ) - ] - return train_data - - def process_remaining_data(self, train_data, results): - return [ - row.to_dict() - for _, row in train_data.iterrows() - if str(row["problem"]) not in results - ] - - -class APPSTaskHandler(TaskHandler): - @staticmethod - def get_question_key(): - return "question" - - @staticmethod - def generate_prompt(test_case, prompt, starter_code=None): - _input = "" - data = test_case - if not data.get("fn_name"): - _input += "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition." # "\nUse Standard Input format"#\n" - else: - _input += "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution." 
# "\nUse Call-Based format"#\n" - data = prompt - _input += data - if starter_code != None: - data = starter_code - data = "\n" + data # + "\n" - _input += data - else: - # _input += "\n\n" - pass - - return _input - - def check_correctness(self, problem, generation): - TIMEOUT = 10 - - def _temp_run(problem, generation, debug, result): - try: - result.append( - apps_run_test(problem=problem, test=generation, debug=debug) - ) - except Exception: - pass - - manager = Manager() - result = manager.list() - p = multiprocessing.Process( - target=_temp_run, args=(problem, generation, False, result) - ) - p.start() - p.join(timeout=TIMEOUT + 1) - if p.is_alive(): - p.kill() - return bool(result and np.all(result[0])) - - def update_results(self, problem, response): - if not isinstance(response, str): - response = response.outputs[0].text.strip() - # Initialize the response structure - response_entry = { - "content": response, - "correctness": None, - "reason": None, - } - code_filter_result = has_code(response) - if len(code_filter_result) == 0: - response_entry["correctness"] = False - response_entry["reason"] = "Does not contain code component." - else: - last_code = code_filter_result[-1] - problem_to_check = copy.deepcopy(problem) - problem_to_check["input_output"] = json.loads(problem["input_output"]) - try: - problem_to_check["solutions"] = json.loads(problem["solutions"]) - except: - problem_to_check["solutions"] = "" - print("Empty solution from the dataset") - curr_res = self.check_correctness(problem_to_check, generation=last_code) - if curr_res: - response_entry["correctness"] = True - response_entry["reason"] = "" - else: - response_entry["correctness"] = False - response_entry["reason"] = "Code is incorrect." - - return response_entry - - def make_conversations(self, data, system_prompt, model=None): - conversations = [] - for problem in data: - test_case = json.loads(problem["input_output"]) - starter_code = problem["starter_code"] - prompt_text = self.generate_prompt( - test_case, problem["question"], starter_code - ) - conversations.append( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text}, - ] - ) - return conversations - - def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None - ): - dataset = load_dataset("codeparrot/apps", trust_remote_code=True) - train_data = dataset[split].to_pandas() - if not filter_difficulty: - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - return ( - train_data.query("difficulty == @source").iloc[start:end] - if end > 0 - else train_data.query("difficulty == @source").iloc[start:] - ) - - def process_remaining_data(self, train_data, results): - return [ - row.to_dict() - for _, row in train_data.iterrows() - if str(row["question"]) not in results - ] - - -class TACOTaskHandler(TaskHandler): - @staticmethod - def get_question_key(): - return "question" - - @staticmethod - def generate_prompt(prompt, starter_code=None, fn_name=None): - _input = "\nQUESTION:\n" - _input += prompt - if starter_code: - _input += starter_code - if (not fn_name) and (not starter_code): - call_format = "\nUse Standard Input format" - _input += call_format - else: - call_format = "\nUse Call-Based format" - _input += call_format - _input += "\nANSWER:\n" - - return _input - - def check_correctness(self, problem, generation): - TIME_OUT = 300 - - def _temp_run(problem, generation, debug, result): - try: - result.append(taco_run_test(problem, 
test=generation, debug=debug)) - except Exception as e: - print(f"Error in _temp_run: {e}") - - manager = Manager() - result = manager.list() - p = multiprocessing.Process( - target=_temp_run, args=(problem, generation, False, result) - ) - p.start() - p.join(timeout=TIME_OUT + 1) - if p.is_alive(): - p.kill() - return bool(result and np.all(result[0])) - - def update_results(self, problem, response): - if not isinstance(response, str): - response = response.outputs[0].text.strip() - # Initialize the response structure - response_entry = { - "content": response, - "correctness": None, - "reason": None, - } - code_filter_result = has_code(response) - if len(code_filter_result) == 0: - response_entry["correctness"] = False - response_entry["reason"] = "Does not contain code component." - else: - last_code = code_filter_result[-1] - curr_res = self.check_correctness(problem, generation=last_code) - if curr_res: - response_entry["correctness"] = True - response_entry["reason"] = "" - else: - response_entry["correctness"] = False - response_entry["reason"] = "Code is incorrect." - - return response_entry - - def make_conversations(self, data, system_prompt, model=None): - conversations = [] - for idx, problem in enumerate(data): - starter_code = ( - None if len(problem["starter_code"]) == 0 else problem["starter_code"] - ) - try: - input_outpout = json.loads(problem["input_output"]) - fn_name = ( - None - if not input_outpout.get("fn_name") - else input_outpout["fn_name"] - ) - except ValueError: - fn_name = None - prompt_text = self.generate_prompt( - problem["question"], starter_code, fn_name - ) - conversations.append( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text}, - ] - ) - return conversations - - def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None - ): - dataset = load_dataset("BAAI/TACO", "ALL", trust_remote_code=True) - train_data = dataset[split].to_pandas() - if not filter_difficulty: - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - return ( - train_data.query("difficulty == @source").iloc[start:end] - if end > 0 - else train_data.query("difficulty == @source").iloc[start:] - ) - - def process_remaining_data(self, train_data, results): - return [ - row.to_dict() - for _, row in train_data.iterrows() - if str(row["question"]) not in results - ] - - -class LiveCodeBenchTaskHandler(TaskHandler): - @staticmethod - def generate_prompt(problem): - # print(problem) - prompt = problem["prompt"] - if problem["is_stdin"]: - return ( - "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition." - + prompt - ) - else: - return ( - "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution." - + prompt - ) - - @staticmethod - def get_question_key(): - return "task_id" - - def check_correctness( - self, - problem: Dict, - completion: str, - timeout: float, - runtime_debug=False, - is_extracted=False, - ) -> Dict: - """ - Evaluates the functional correctness of a completion by running the test - suite provided in the problem. - - :param completion_id: an optional completion ID so we can match - the results later even if execution finishes asynchronously. 
- """ - result_list = unsafe_lcb_runTests( - problem, completion, timeout, runtime_debug, is_extracted - ) - details = [r[0] for r in result_list] - all_passed = all(details) - - result = "" - if result_list and all_passed: - result = "passed" - - return result == "passed" - - def update_results(self, problem, response): - if not isinstance(response, str): - response = response.outputs[0].text.strip() - # Initialize the response structure - response_entry = { - "content": response, - "correctness": None, - "reason": None, - } - code_filter_result = has_code(response) - # print(response) - if len(code_filter_result) == 0: - response_entry["correctness"] = False - response_entry["reason"] = "Does not contain code component." - else: - last_code = code_filter_result[-1] - problem_to_check = copy.deepcopy(problem) - - curr_res = self.check_correctness( - problem=problem_to_check, - completion=post_process_code(last_code), - timeout=6, - is_extracted=not problem_to_check["is_stdin"], - ) - if curr_res: - response_entry["correctness"] = True - response_entry["reason"] = "" - else: - response_entry["correctness"] = False - response_entry["reason"] = "Code is incorrect." - - return response_entry - - def make_conversations(self, data, system_prompt, model=None): - conversations = [] - for problem in data: - prompt_text = self.generate_prompt(problem) - conversations.append( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text}, - ] - ) - return conversations - - def load_and_filter_dataset( - self, start, end, split="test", source=None, filter_difficulty=False, args=None - ): - dataset = load_dataset( - "livecodebench/code_generation_lite", - version_tag="release_v2", - split=split, - trust_remote_code=True, - ) - if filter_difficulty: - dataset = dataset.filter(lambda example: example["difficulty"] == source) - dataset = dataset.map( - lambda example: { - "private_test_cases": translate_private_test_cases( - example["private_test_cases"] - ) - } - ) - # Apply the mapping function - dataset = dataset.map( - map_to_example, remove_columns=dataset.column_names - ).to_pandas() - return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] - - def process_remaining_data(self, train_data, results): - return [ - row.to_dict() - for _, row in train_data.iterrows() - if str(row["task_id"]) not in results - ] - - -class GSM8KTaskHandler(TaskHandler): - def __init__(self) -> None: - super().__init__() - self.dataset = "openai/gsm8k" - self.ans_re = re.compile(r"((-?[$0-9.,]{2,})|(-?[0-9]+))") - self.gt_re = re.compile(r"#### (\-?[0-9\.\,]+)") - self.invalid_ans = "[invalid]" - - @staticmethod - def get_question_key(): - return "question" - - @staticmethod - def generate_prompt(problem): - question = problem["question"] - full_prompt = f'Given the following problem, reason and give a final answer to the problem.\nProblem: {question}\nYour response should end with "The final answer is [answer]" where [answer] is the response to the problem.' 
- return full_prompt - - def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: - gt_answer = self.extract_gt_answer(problem["answer"]) - model_answer = extract_answer(generation) - model_answer = self.sanitize_answer(model_answer) - return model_answer == gt_answer - - def update_results(self, problem, response): - if not isinstance(response, str): - response = response.outputs[0].text.strip() - # Initialize the response structure - response_entry = { - "content": response, - "correctness": None, - "reason": None, - } - curr_res = self.check_correctness(problem, generation=response) - if curr_res: - response_entry["correctness"] = True - response_entry["reason"] = "" - else: - response_entry["correctness"] = False - response_entry["reason"] = "Solution is incorrect." - - return response_entry - - def make_conversations(self, data, system_prompt, model=None): - conversations = [] - for problem in data: - prompt_text = self.generate_prompt(problem) - conversations.append( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text}, - ] - ) - return conversations - - def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None - ): - dataset = load_dataset(self.dataset, "main") - train_data = dataset[split].to_pandas() - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - - def process_remaining_data(self, train_data, results): - return [ - row.to_dict() - for _, row in train_data.iterrows() - if str(row["question"]) not in results - ] - - def extract_gt_answer(self, completion): - match = self.gt_re.search(completion) - if match: - match_str = match.group(1).strip() - match_str = match_str.replace(",", "") - return match_str - else: - return self.invalid_ans - - def sanitize_answer(self, answer): - patterns_to_remove = [ - ",", # Remove commas - r"\$", # Remove dollar signs - r"\.$" r"\*", # Remove trailing period # Remove asterisks - ] - for pattern in patterns_to_remove: - answer = re.sub(pattern, "", answer) - - matches = self.ans_re.findall(answer) - if matches: - # get the last match (i.e final response) and the first / outer capturing group - match_str = matches[-1][0].strip() - return match_str - else: - return self.invalid_ans - - -class ARCChallengeTaskHandler(TaskHandler): - def __init__(self) -> None: - super().__init__() - self.dataset = "allenai/ai2_arc" - self.ans_re = re.compile(r"[Tt]he best answer is ([A-D])[\.\,]*", re.IGNORECASE) - self.letter_re = re.compile(r"([A-D])[\.\,]*") - self.canonical_options = ["A", "B", "C", "D"] - self.invalid_ans = "[invalid]" - - @staticmethod - def get_question_key(): - return "question" - - @staticmethod - def generate_prompt(problem): - question = problem["question"] - choices = problem["choices"] - choices_text = "\n".join( - [ - f"{label}.{choice}" - for label, choice in zip(["A", "B", "C", "D"], choices["text"]) - ] - ) - full_prompt = ( - 'Given the following question and four candidate answers (A, B, C and D), choose the best answer. 
Your response should end with "The best answer is [the_answer_letter]" where [the_answer_letter] is one of the four letter choice (A, B, C, or D).\n' - + f"{question}\n{choices_text}" - ) - return full_prompt - - def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: - gt_answer = problem["answerKey"] - if gt_answer not in self.canonical_options: - gt_answer = self.canonical_options[int(problem["answerKey"]) - 1] - model_answer = self.get_answer(generation) - return model_answer == gt_answer - - def update_results(self, problem, response): - if not isinstance(response, str): - response = response.outputs[0].text.strip() - # Initialize the response structure - response_entry = { - "content": response, - "correctness": None, - "reason": None, - } - curr_res = self.check_correctness(problem, generation=response) - if curr_res: - response_entry["correctness"] = True - response_entry["reason"] = "" - else: - response_entry["correctness"] = False - response_entry["reason"] = "Solution is incorrect." - - return response_entry - - def make_conversations(self, data, system_prompt, model=None): - conversations = [] - for problem in data: - prompt_text = self.generate_prompt(problem) - conversations.append( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text}, - ] - ) - return conversations - - def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None - ): - dataset = load_dataset(self.dataset, "ARC-Challenge") - train_data = dataset[split].to_pandas() - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - - def process_remaining_data(self, train_data, results): - return [ - row.to_dict() - for _, row in train_data.iterrows() - if str(row["question"]) not in results - ] - - def get_answer(self, completion): - # First, we try to extract similar to MATH answers - answer = extract_answer(completion) - match = None - if answer: - # match for the letter answer needed. 
- match = self.letter_re.search(answer) - if match: - return match.group(1).strip() - - if not answer or not match: - # try basic-regex based search - patterns_to_remove = [ - ",", # Remove commas - r"\$", # Remove dollar signs - r"\.$" r"\\", # Remove trailing period # Remove stray backslashes - r"\*", # Remove asterisks - ] - answer = completion - for pattern in patterns_to_remove: - answer = re.sub(pattern, "", answer) - matches = self.ans_re.findall(answer) - if not matches: - return self.invalid_ans - return matches[-1].strip() - - -class AMC23TaskHandler(MathTaskHandler): - def __init__(self): - self.dataset = "AI-MO/aimo-validation-amc" - - @staticmethod - def get_question_key(): - return "problem" - - def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None - ): - dataset = load_dataset(self.dataset) - train_data = dataset[split].to_pandas() - filtered_data = train_data[train_data["url"].str.contains("2023", na=False)] - return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] - - -TASK_HANDLERS = { - "NUMINA": NUMINATaskHandler, - "APPS": APPSTaskHandler, - "TACO": TACOTaskHandler, - "MATH500": MATH500TaskHandler, - "AIME": AIMETaskHandler, - "GPQADiamond": GPQADiamondTaskHandler, - "MMLU": MMLUTaskHandler, - "MMLUPro": MMLUProTaskHandler, - "LiveCodeBench": LiveCodeBenchTaskHandler, - "GSM8K": GSM8KTaskHandler, - "ARC-C": ARCChallengeTaskHandler, - "AMC23": AMC23TaskHandler, -} From faaf2938b69f423df7f2ebc43cfc271008e6d044 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 04:16:43 +0000 Subject: [PATCH 08/47] x Signed-off-by: SumanthRH --- skythought/tools/inference_and_check.py | 3 +- skythought/tools/tasks/__init__.py | 6 +- skythought/tools/tasks/aime/aime_handler.py | 8 +- skythought/tools/tasks/amc23/amc23_handler.py | 22 +++--- skythought/tools/tasks/apps/apps_handler.py | 7 +- skythought/tools/tasks/arc/arc_handler.py | 3 +- skythought/tools/tasks/common.py | 27 +++++-- skythought/tools/tasks/gsm8k/gsm8k_handler.py | 68 ++++++++-------- .../livecodebench/livecodebench_handler.py | 3 +- skythought/tools/tasks/math/math500.yaml | 11 +++ .../tools/tasks/math/math500_handler.py | 10 --- skythought/tools/tasks/math/math_handler.py | 57 ++++++-------- skythought/tools/tasks/mmlu/mmlu_handler.py | 77 +++++++++++-------- .../tools/tasks/numina/numina_handler.py | 60 ++++++++++----- skythought/tools/tasks/taco/taco_handler.py | 5 +- skythought/tools/util/common.py | 9 +++ 16 files changed, 207 insertions(+), 169 deletions(-) delete mode 100644 skythought/tools/tasks/math/math500_handler.py diff --git a/skythought/tools/inference_and_check.py b/skythought/tools/inference_and_check.py index 816477e..bdc429e 100644 --- a/skythought/tools/inference_and_check.py +++ b/skythought/tools/inference_and_check.py @@ -436,6 +436,7 @@ def main(): ], help="Dataset to process.", ) + parser.add_argument("--config", type=str, help="Path to the config file.") parser.add_argument( "--model", type=str, @@ -492,7 +493,7 @@ def main(): ) args = parser.parse_args() - handler: TaskHandler = TASK_HANDLERS[args.dataset]() + handler: TaskHandler = TASK_HANDLERS[args.dataset](args.config) temperatures = [1] if args.model.startswith("openai/o1") else args.temperatures print(f"Temperature: {temperatures}") diff --git a/skythought/tools/tasks/__init__.py b/skythought/tools/tasks/__init__.py index 8dc2388..b20b406 100644 --- a/skythought/tools/tasks/__init__.py +++ b/skythought/tools/tasks/__init__.py @@ -6,7 +6,7 @@ from 
.gpqa_diamond.gpqa_diamond_handler import GPQADiamondTaskHandler from .gsm8k.gsm8k_handler import GSM8KTaskHandler from .livecodebench.livecodebench_handler import LiveCodeBenchTaskHandler -from .math.math_handler import MATH500TaskHandler, MathTaskHandler +from .math.math_handler import MathTaskHandler from .mmlu.mmlu_handler import MMLUProTaskHandler, MMLUTaskHandler from .numina.numina_handler import NUMINATaskHandler from .taco.taco_handler import TACOTaskHandler @@ -15,7 +15,7 @@ "NUMINA": NUMINATaskHandler, "APPS": APPSTaskHandler, "TACO": TACOTaskHandler, - "MATH500": MATH500TaskHandler, + "MATH500": MathTaskHandler, "AIME": AIMETaskHandler, "GPQADiamond": GPQADiamondTaskHandler, "MMLU": MMLUTaskHandler, @@ -31,7 +31,7 @@ AIMETaskHandler, APPSTaskHandler, TACOTaskHandler, - MATH500TaskHandler, + MathTaskHandler, AMC23TaskHandler, NUMINATaskHandler, GPQADiamondTaskHandler, diff --git a/skythought/tools/tasks/aime/aime_handler.py b/skythought/tools/tasks/aime/aime_handler.py index c0504d8..2ec7014 100644 --- a/skythought/tools/tasks/aime/aime_handler.py +++ b/skythought/tools/tasks/aime/aime_handler.py @@ -1,6 +1,6 @@ from datasets import load_dataset -from tasks.common import MathTaskHandler +from tasks.math.math_handler import MathTaskHandler from util.model_utils import MODEL_TO_NAME @@ -8,8 +8,7 @@ class AIMETaskHandler(MathTaskHandler): def __init__(self): self.dataset = "AI-MO/aimo-validation-aime" - @staticmethod - def generate_prompt(prompt, model): + def generate_prompt(self, prompt, model): if MODEL_TO_NAME[model] == "Sky-T1-32B-Preview": return prompt + "\nReturn your final response within \\boxed{{}}" else: @@ -38,6 +37,3 @@ def load_and_filter_dataset( train_data = dataset[split].to_pandas() filtered_data = train_data[train_data["url"].str.contains("2024", na=False)] return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] - - def dummy(): - raise NotImplementedError() diff --git a/skythought/tools/tasks/amc23/amc23_handler.py b/skythought/tools/tasks/amc23/amc23_handler.py index fbacdb2..878e7d8 100644 --- a/skythought/tools/tasks/amc23/amc23_handler.py +++ b/skythought/tools/tasks/amc23/amc23_handler.py @@ -1,13 +1,7 @@ from datasets import load_dataset -from typing import Dict, Any -from multiprocessing import Manager -from tasks.apps.apps_util import run_test as apps_run_test -from tasks.taco.taco_util import run_test as taco_run_test -from util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer -from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases -from util.common import TimeoutException, timeout -from util.model_utils import SYSTEM_PROMPT, MODEL_TO_NAME -from tasks.common import MathTaskHandler + +from tasks.math.math_handler import MathTaskHandler + class AMC23TaskHandler(MathTaskHandler): def __init__(self): @@ -16,9 +10,11 @@ def __init__(self): @staticmethod def get_question_key(): return "problem" - - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset) train_data = dataset[split].to_pandas() - filtered_data = train_data[train_data['url'].str.contains("2023", na=False)] - return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] \ No 
newline at end of file + filtered_data = train_data[train_data["url"].str.contains("2023", na=False)] + return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] diff --git a/skythought/tools/tasks/apps/apps_handler.py b/skythought/tools/tasks/apps/apps_handler.py index 603191c..d32e69e 100644 --- a/skythought/tools/tasks/apps/apps_handler.py +++ b/skythought/tools/tasks/apps/apps_handler.py @@ -17,8 +17,7 @@ class APPSTaskHandler(TaskHandler): def get_question_key(): return "question" - @staticmethod - def generate_prompt(test_case, prompt, starter_code=None): + def generate_prompt(self, test_case, prompt, starter_code=None): _input = "" data = test_case if not data.get("fn_name"): @@ -27,7 +26,7 @@ def generate_prompt(test_case, prompt, starter_code=None): _input += "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution." # "\nUse Call-Based format"#\n" data = prompt _input += data - if starter_code != None: + if starter_code is not None: data = starter_code data = "\n" + data # + "\n" _input += data @@ -78,7 +77,7 @@ def update_results(self, problem, response): problem_to_check["input_output"] = json.loads(problem["input_output"]) try: problem_to_check["solutions"] = json.loads(problem["solutions"]) - except: + except Exception: problem_to_check["solutions"] = "" print("Empty solution from the dataset") curr_res = self.check_correctness(problem_to_check, generation=last_code) diff --git a/skythought/tools/tasks/arc/arc_handler.py b/skythought/tools/tasks/arc/arc_handler.py index 3c6c3aa..6caa056 100644 --- a/skythought/tools/tasks/arc/arc_handler.py +++ b/skythought/tools/tasks/arc/arc_handler.py @@ -20,8 +20,7 @@ def __init__(self) -> None: def get_question_key(): return "question" - @staticmethod - def generate_prompt(problem): + def generate_prompt(self, problem): question = problem["question"] choices = problem["choices"] choices_text = "\n".join( diff --git a/skythought/tools/tasks/common.py b/skythought/tools/tasks/common.py index cd6e88a..40306eb 100644 --- a/skythought/tools/tasks/common.py +++ b/skythought/tools/tasks/common.py @@ -1,17 +1,30 @@ import json import os -import re +from typing import Any, Dict, List, Optional +import yaml +from pydantic import BaseModel, Field -def has_code(response): - pattern = r"```(?:[a-zA-Z]*)\n(.*?)```" - # Use re.DOTALL to match multiline content inside backticks - matches = re.findall(pattern, response, re.DOTALL) - # print(matches) - return matches + +class TaskConfig(BaseModel): + dataset_name: str + dataset_source: Optional[str] = None + question_key: str + templating_parameters: Dict[str, str] = Field(default_factory=dict) + fewshot_config: List[Dict[str, Any]] = Field(default_factory=list) + num_fewshot: int = 0 class TaskHandler: + def __init__(self, yaml_file_path): + self.yaml_file_path = yaml_file_path + self.task_config = TaskConfig(**self.load_yaml(yaml_file_path)) + + @staticmethod + def load_yaml(yaml_file_path): + with open(yaml_file_path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + @staticmethod def get_question_key(): raise NotImplementedError("Subclasses should implement this method.") diff --git a/skythought/tools/tasks/gsm8k/gsm8k_handler.py b/skythought/tools/tasks/gsm8k/gsm8k_handler.py index 07357cd..c10d144 100644 --- a/skythought/tools/tasks/gsm8k/gsm8k_handler.py +++ b/skythought/tools/tasks/gsm8k/gsm8k_handler.py @@ -1,14 +1,10 @@ -import re +import re +from typing import Any, Dict + from 
datasets import load_dataset -from typing import Dict, Any -from multiprocessing import Manager -from tasks.apps.apps_util import run_test as apps_run_test -from tasks.taco.taco_util import run_test as taco_run_test -from util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer -from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases -from util.common import TimeoutException, timeout -from util.model_utils import SYSTEM_PROMPT, MODEL_TO_NAME + from tasks.common import TaskHandler +from util.math_parsing_util import extract_answer class GSM8KTaskHandler(TaskHandler): @@ -16,25 +12,24 @@ def __init__(self) -> None: super().__init__() self.dataset = "openai/gsm8k" self.ans_re = re.compile(r"((-?[$0-9.,]{2,})|(-?[0-9]+))") - self.gt_re = re.compile(r"#### (\-?[0-9\.\,]+)") + self.gt_re = re.compile(r"#### (\-?[0-9\.\,]+)") self.invalid_ans = "[invalid]" @staticmethod def get_question_key(): return "question" - @staticmethod - def generate_prompt(problem): - question = problem["question"] - full_prompt = f"Given the following problem, reason and give a final answer to the problem.\nProblem: {question}\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem." + def generate_prompt(self, problem): + question = problem["question"] + full_prompt = f'Given the following problem, reason and give a final answer to the problem.\nProblem: {question}\nYour response should end with "The final answer is [answer]" where [answer] is the response to the problem.' return full_prompt - - def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: + + def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: gt_answer = self.extract_gt_answer(problem["answer"]) model_answer = extract_answer(generation) model_answer = self.sanitize_answer(model_answer) return model_answer == gt_answer - + def update_results(self, problem, response): if not isinstance(response, str): response = response.outputs[0].text.strip() @@ -44,34 +39,42 @@ def update_results(self, problem, response): "correctness": None, "reason": None, } - curr_res= self.check_correctness(problem, generation=response) + curr_res = self.check_correctness(problem, generation=response) if curr_res: response_entry["correctness"] = True response_entry["reason"] = "" else: response_entry["correctness"] = False response_entry["reason"] = "Solution is incorrect." 
- + return response_entry def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: prompt_text = self.generate_prompt(problem) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset, "main") train_data = dataset[split].to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["question"]) not in results] - + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["question"]) not in results + ] + def extract_gt_answer(self, completion): match = self.gt_re.search(completion) if match: @@ -83,18 +86,17 @@ def extract_gt_answer(self, completion): def sanitize_answer(self, answer): patterns_to_remove = [ - ',', # Remove commas - r'\$', # Remove dollar signs - r'\.$' # Remove trailing period - r"\*", # Remove asterisks + ",", # Remove commas + r"\$", # Remove dollar signs + r"\.$" r"\*", # Remove trailing period # Remove asterisks ] for pattern in patterns_to_remove: - answer = re.sub(pattern, '', answer) - + answer = re.sub(pattern, "", answer) + matches = self.ans_re.findall(answer) if matches: # get the last match (i.e final response) and the first / outer capturing group match_str = matches[-1][0].strip() return match_str else: - return self.invalid_ans \ No newline at end of file + return self.invalid_ans diff --git a/skythought/tools/tasks/livecodebench/livecodebench_handler.py b/skythought/tools/tasks/livecodebench/livecodebench_handler.py index ee51746..317a0f8 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench_handler.py +++ b/skythought/tools/tasks/livecodebench/livecodebench_handler.py @@ -14,8 +14,7 @@ class LiveCodeBenchTaskHandler(TaskHandler): - @staticmethod - def generate_prompt(problem): + def generate_prompt(self, problem): # print(problem) prompt = problem["prompt"] if problem["is_stdin"]: diff --git a/skythought/tools/tasks/math/math500.yaml b/skythought/tools/tasks/math/math500.yaml index e69de29..ba45bf2 100644 --- a/skythought/tools/tasks/math/math500.yaml +++ b/skythought/tools/tasks/math/math500.yaml @@ -0,0 +1,11 @@ +dataset_name: "qq8933/MATH500" # repo ID in huggingface +dataset_source: null # which subset on huggingface +question_key: problem +split: test +templating_parameters: + - instruction: "Return your final response within \\boxed{{}}. " + # optional. Not supported yet. +fewshot_config: + - question: ... + - target: ... 
+num_fewshot: 0 diff --git a/skythought/tools/tasks/math/math500_handler.py b/skythought/tools/tasks/math/math500_handler.py deleted file mode 100644 index 00571bb..0000000 --- a/skythought/tools/tasks/math/math500_handler.py +++ /dev/null @@ -1,10 +0,0 @@ -from datasets import load_dataset -from typing import Dict, Any -from multiprocessing import Manager -from tasks.apps.apps_util import run_test as apps_run_test -from tasks.taco.taco_util import run_test as taco_run_test -from util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer -from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases -from util.common import TimeoutException, timeout -from util.model_utils import SYSTEM_PROMPT, MODEL_TO_NAME -from tasks.common import MathTaskHandler diff --git a/skythought/tools/tasks/math/math_handler.py b/skythought/tools/tasks/math/math_handler.py index 82f41c1..791b908 100644 --- a/skythought/tools/tasks/math/math_handler.py +++ b/skythought/tools/tasks/math/math_handler.py @@ -1,27 +1,19 @@ from datasets import load_dataset -from typing import Dict, Any -from multiprocessing import Manager -from tasks.apps.apps_util import run_test as apps_run_test -from tasks.taco.taco_util import run_test as taco_run_test -from util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer -from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases -from util.common import TimeoutException, timeout -from util.model_utils import SYSTEM_PROMPT, MODEL_TO_NAME + from tasks.common import TaskHandler +from util.math_parsing_util import extract_answer, math_equal, strip_answer_string class MathTaskHandler(TaskHandler): - @staticmethod - def generate_prompt(prompt): - return "Return your final response within \\boxed{{}}. " + prompt - + def generate_prompt(self, prompt): + return self.task_config.templating_parameters["instruction"] + prompt + def check_correctness(self, problem, generation): answer = strip_answer_string(problem["answer"]) pred = extract_answer(generation) - # print(problem) pred = strip_answer_string(pred) return math_equal(pred, answer) - + def update_results(self, problem, response): if not isinstance(response, str): response = response.outputs[0].text.strip() @@ -38,32 +30,31 @@ def update_results(self, problem, response): else: response_entry["correctness"] = False response_entry["reason"] = "Solution is incorrect." 
- + return response_entry - + def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: - prompt_text = self.generate_prompt(problem["problem"]) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + prompt_text = self.generate_prompt(problem[self.task_config.question_key]) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - + def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["problem"]) not in results] + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row[self.task_config.question_key]) not in results + ] - def load_and_filter_dataset(self, start, end, split="test", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="test", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset) train_data = dataset[split].to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - - -class MATH500TaskHandler(MathTaskHandler): - def __init__(self): - self.dataset = "qq8933/MATH500" - - @staticmethod - def get_question_key(): - return "problem" diff --git a/skythought/tools/tasks/mmlu/mmlu_handler.py b/skythought/tools/tasks/mmlu/mmlu_handler.py index bd293f0..1f7cd53 100644 --- a/skythought/tools/tasks/mmlu/mmlu_handler.py +++ b/skythought/tools/tasks/mmlu/mmlu_handler.py @@ -1,28 +1,15 @@ -import copy -import json -import multiprocessing -import os -import random -import re -import numpy as np from datasets import load_dataset -from typing import Dict, Any -from multiprocessing import Manager -from tasks.apps.apps_util import run_test as apps_run_test -from tasks.taco.taco_util import run_test as taco_run_test -from util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer -from tasks.livecodebench.livecodebench_util import unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases -from util.common import TimeoutException, timeout -from util.model_utils import SYSTEM_PROMPT + +from util.math_parsing_util import get_multiple_choice_answer, mmlu_pro_extract_answer from ..common import TaskHandler + class MMLUTaskHandler(TaskHandler): def __init__(self): self.dataset = "cais/mmlu" - @staticmethod - def generate_prompt(prompt): + def generate_prompt(self, prompt): return "Return your final response within \\boxed{{}}. " + prompt @staticmethod @@ -52,40 +39,66 @@ def update_results(self, problem, response): response_entry["correctness"] = False response_entry["reason"] = "Solution is incorrect." 
return response_entry - + def get_multiple_choice_answers(self, problem): options = problem["choices"] for i, (label, option) in enumerate(zip("ABCD", options)): options[i] = f"({label}) {str(option).strip()}" options = " ".join(options) return f"Answer Choices: {options}" - + def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: multiple_choice_string = self.get_multiple_choice_answers(problem) - prompt_text = self.generate_prompt(problem["question"] + "\n" + multiple_choice_string) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + prompt_text = self.generate_prompt( + problem["question"] + "\n" + multiple_choice_string + ) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["question"]) not in results] + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["question"]) not in results + ] - def load_and_filter_dataset(self, start, end, split="test", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="test", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset, "all") train_data = dataset[split].to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - class MMLUProTaskHandler(MMLUTaskHandler): def __init__(self): super().__init__() self.dataset = "TIGER-Lab/MMLU-Pro" - self.choices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P"] + self.choices = [ + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "J", + "K", + "L", + "M", + "N", + "O", + "P", + ] @staticmethod def generate_prompt(prompt): @@ -102,12 +115,14 @@ def check_correctness(self, problem, generation): def get_multiple_choice_answers(self, problem): options = problem["options"] - for i, (label, option) in enumerate(zip(self.choices[:len(options)], options)): + for i, (label, option) in enumerate(zip(self.choices[: len(options)], options)): options[i] = f"({label}) {str(option).strip()}" options = " ".join(options) return f"Answer Choices: {options}" - def load_and_filter_dataset(self, start, end, split="test", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="test", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset(self.dataset, "default") train_data = dataset[split].to_pandas() - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] \ No newline at end of file + return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] diff --git a/skythought/tools/tasks/numina/numina_handler.py b/skythought/tools/tasks/numina/numina_handler.py index a086ae3..d1093c5 100644 --- a/skythought/tools/tasks/numina/numina_handler.py +++ b/skythought/tools/tasks/numina/numina_handler.py @@ -1,23 +1,18 @@ from datasets import load_dataset -from typing import Dict, Any -from multiprocessing import Manager -from tasks.apps.apps_util import run_test as apps_run_test -from tasks.taco.taco_util import run_test as taco_run_test -from util.math_parsing_util import strip_answer_string, get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer -from tasks.livecodebench.livecodebench_util import 
unsafe_lcb_runTests, map_to_example, has_test_type, post_process_code, translate_private_test_cases -from util.common import TimeoutException, timeout -from util.model_utils import SYSTEM_PROMPT, MODEL_TO_NAME + from tasks.common import TaskHandler +from util.common import TimeoutException, timeout +from util.math_parsing_util import extract_answer, math_equal, strip_answer_string + class NUMINATaskHandler(TaskHandler): @staticmethod def get_question_key(): return "problem" - @staticmethod - def generate_prompt(prompt): + def generate_prompt(self, prompt): return "Return your final response within \\boxed{{}}. " + prompt - + @timeout(5) # Add timeout of 5 seconds def check_correctness(self, problem, generation): solution = extract_answer(problem["solution"]) @@ -25,7 +20,7 @@ def check_correctness(self, problem, generation): pred = extract_answer(generation) pred = strip_answer_string(pred) return math_equal(pred, solution) - + def update_results(self, problem, response): if not isinstance(response, str): response = response.outputs[0].text.strip() @@ -53,7 +48,11 @@ def update_results(self, problem, response): @staticmethod def get_difficulty_dict(source, start, end): diff_dict = {} - dataset = load_dataset("NovaSky-AI/labeled_numina_difficulty_859K", trust_remote_code=True, split="train") + dataset = load_dataset( + "NovaSky-AI/labeled_numina_difficulty_859K", + trust_remote_code=True, + split="train", + ) for example in dataset: # print(example) diff_dict[example["problem"]] = example["gpt_difficulty_parsed"] @@ -63,21 +62,40 @@ def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: prompt_text = self.generate_prompt(problem["problem"]) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): + def load_and_filter_dataset( + self, start, end, split="train", source=None, filter_difficulty=False, args=None + ): dataset = load_dataset("AI-MO/NuminaMath-CoT") train_data = dataset[split].to_pandas() - train_data = train_data.query('source == @source').iloc[start:end] if end > 0 else train_data.query('source == @source').iloc[start:] + train_data = ( + train_data.query("source == @source").iloc[start:end] + if end > 0 + else train_data.query("source == @source").iloc[start:] + ) train_data = train_data[train_data["solution"].str.contains("boxed", na=False)] if filter_difficulty: diff_dict = self.get_difficulty_dict(source, start, end) - train_data = train_data[train_data["problem"].map(diff_dict).apply(lambda x: x >= args.math_difficulty_lower_bound and x <= args.math_difficulty_upper_bound)] + train_data = train_data[ + train_data["problem"] + .map(diff_dict) + .apply( + lambda x: x >= args.math_difficulty_lower_bound + and x <= args.math_difficulty_upper_bound + ) + ] return train_data def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["problem"]) not in results] \ No newline at end of file + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["problem"]) not in results + ] diff --git a/skythought/tools/tasks/taco/taco_handler.py b/skythought/tools/tasks/taco/taco_handler.py index 917207d..ccb1692 100644 --- 
a/skythought/tools/tasks/taco/taco_handler.py +++ b/skythought/tools/tasks/taco/taco_handler.py @@ -16,8 +16,7 @@ class TACOTaskHandler(TaskHandler): def get_question_key(): return "question" - @staticmethod - def generate_prompt(prompt, starter_code=None, fn_name=None): + def generate_prompt(self, prompt, starter_code=None, fn_name=None): _input = "\nQUESTION:\n" _input += prompt if starter_code: @@ -79,7 +78,7 @@ def update_results(self, problem, response): def make_conversations(self, data, system_prompt, model=None): conversations = [] - for idx, problem in enumerate(data): + for _, problem in enumerate(data): starter_code = ( None if len(problem["starter_code"]) == 0 else problem["starter_code"] ) diff --git a/skythought/tools/util/common.py b/skythought/tools/util/common.py index 1094cdf..e24bc23 100644 --- a/skythought/tools/util/common.py +++ b/skythought/tools/util/common.py @@ -1,4 +1,5 @@ import multiprocessing +import re class TimeoutException(Exception): @@ -44,3 +45,11 @@ def target(queue, *args, **kwargs): return wrapper return decorator + + +def has_code(response): + pattern = r"```(?:[a-zA-Z]*)\n(.*?)```" + # Use re.DOTALL to match multiline content inside backticks + matches = re.findall(pattern, response, re.DOTALL) + # print(matches) + return matches From c18ad7b6921b4634492efcc762ef81dea630065e Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 08:52:21 +0000 Subject: [PATCH 09/47] extra large commit Signed-off-by: SumanthRH --- skythought/tools/inference_and_check.py | 6 +- skythought/tools/tasks/aime/aime.yaml | 7 ++ skythought/tools/tasks/aime/aime_handler.py | 21 +++-- skythought/tools/tasks/amc23/amc23.yaml | 6 ++ skythought/tools/tasks/amc23/amc23_handler.py | 12 +-- skythought/tools/tasks/apps/apps.yaml | 10 +++ skythought/tools/tasks/apps/apps_handler.py | 51 +++++------- skythought/tools/tasks/arc/arc_c.yaml | 7 ++ skythought/tools/tasks/arc/arc_handler.py | 27 +++---- skythought/tools/tasks/common.py | 25 +++++- .../tasks/gpqa_diamond/gpqa_diamond.yaml | 7 ++ .../gpqa_diamond/gpqa_diamond_handler.py | 78 ++++++++++--------- skythought/tools/tasks/gsm8k/gsm8k.yaml | 7 ++ skythought/tools/tasks/gsm8k/gsm8k_handler.py | 15 +--- .../tasks/livecodebench/livecodebench.yaml | 11 +++ .../livecodebench/livecodebench_handler.py | 50 ++++++------ skythought/tools/tasks/math/math500.yaml | 4 +- skythought/tools/tasks/math/math_handler.py | 17 ++-- skythought/tools/tasks/mmlu/mmlu.yaml | 5 ++ skythought/tools/tasks/mmlu/mmlu_handler.py | 28 ++----- skythought/tools/tasks/mmlu/mmlu_pro.yaml | 5 ++ skythought/tools/tasks/numina/numina.yaml | 6 ++ .../tools/tasks/numina/numina_handler.py | 29 ++++--- skythought/tools/tasks/taco/taco.yaml | 14 ++++ skythought/tools/tasks/taco/taco_handler.py | 53 +++++++------ 25 files changed, 277 insertions(+), 224 deletions(-) create mode 100644 skythought/tools/tasks/arc/arc_c.yaml create mode 100644 skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml create mode 100644 skythought/tools/tasks/gsm8k/gsm8k.yaml create mode 100644 skythought/tools/tasks/mmlu/mmlu.yaml create mode 100644 skythought/tools/tasks/mmlu/mmlu_pro.yaml create mode 100644 skythought/tools/tasks/numina/numina.yaml diff --git a/skythought/tools/inference_and_check.py b/skythought/tools/inference_and_check.py index bdc429e..fdb46d5 100644 --- a/skythought/tools/inference_and_check.py +++ b/skythought/tools/inference_and_check.py @@ -124,7 +124,7 @@ def perform_inference_and_check( total_correct += response_entry["correctness"] total_finish += 1 - 
problem_key = remaining_data[idx][handler.get_question_key()] + problem_key = remaining_data[idx][handler.question_key] if problem_key not in results: results[problem_key] = remaining_data[idx] if isinstance(handler, NUMINATaskHandler): @@ -211,7 +211,7 @@ def perform_check(handler: TaskHandler, temperatures, result_file, args): tasks = [] for item in remaining_data: - problem_key = item[handler.get_question_key()] + problem_key = item[handler.question_key] # If this item exists in the results file, check each temperature if problem_key in results and "responses" in results[problem_key]: for temp in temperatures: @@ -359,7 +359,7 @@ def perform_inference_and_save( completion_tokens.append(completion_token) problem_key = remaining_data[idx][ - handler.get_question_key() + handler.question_key ] # can you use this idx if problem_key not in results: results[problem_key] = remaining_data[idx] diff --git a/skythought/tools/tasks/aime/aime.yaml b/skythought/tools/tasks/aime/aime.yaml index e69de29..b2e1fcc 100644 --- a/skythought/tools/tasks/aime/aime.yaml +++ b/skythought/tools/tasks/aime/aime.yaml @@ -0,0 +1,7 @@ + +dataset_source: AI-MO/aimo-validation-aime +dataset_split: train +question_key: problem +templating_parameters: + regular_template: "Return your final response within \\boxed{{}}. {prompt}" + sky_template: "{prompt}\nReturn your final response within \\boxed{{}}" \ No newline at end of file diff --git a/skythought/tools/tasks/aime/aime_handler.py b/skythought/tools/tasks/aime/aime_handler.py index 2ec7014..9a488c1 100644 --- a/skythought/tools/tasks/aime/aime_handler.py +++ b/skythought/tools/tasks/aime/aime_handler.py @@ -1,5 +1,3 @@ -from datasets import load_dataset - from tasks.math.math_handler import MathTaskHandler from util.model_utils import MODEL_TO_NAME @@ -10,18 +8,18 @@ def __init__(self): def generate_prompt(self, prompt, model): if MODEL_TO_NAME[model] == "Sky-T1-32B-Preview": - return prompt + "\nReturn your final response within \\boxed{{}}" + return self.task_config.templating_parameters["sky_template"].format( + prompt=prompt + ) else: - return "Return your final response within \\boxed{{}}. 
" + prompt - - @staticmethod - def get_question_key(): - return "problem" + return self.task_config.templating_parameters["regular_template"].format( + prompt=prompt + ) def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: - prompt_text = self.generate_prompt(problem["problem"], model) + prompt_text = self.generate_prompt(problem[self.question_key], model) conversations.append( [ {"role": "system", "content": system_prompt}, @@ -31,9 +29,8 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=False, args=None ): - dataset = load_dataset(self.dataset) - train_data = dataset[split].to_pandas() + train_data = self.load_dataset(source=source, split=split) filtered_data = train_data[train_data["url"].str.contains("2024", na=False)] return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] diff --git a/skythought/tools/tasks/amc23/amc23.yaml b/skythought/tools/tasks/amc23/amc23.yaml index e69de29..6430bf7 100644 --- a/skythought/tools/tasks/amc23/amc23.yaml +++ b/skythought/tools/tasks/amc23/amc23.yaml @@ -0,0 +1,6 @@ +dataset_name: AI-MO/aimo-validation-amc +dataset_kwargs: + trust_remote_code: true +dataset_split: train +question_key: problem +difficulty: null diff --git a/skythought/tools/tasks/amc23/amc23_handler.py b/skythought/tools/tasks/amc23/amc23_handler.py index 878e7d8..7087089 100644 --- a/skythought/tools/tasks/amc23/amc23_handler.py +++ b/skythought/tools/tasks/amc23/amc23_handler.py @@ -1,20 +1,10 @@ -from datasets import load_dataset - from tasks.math.math_handler import MathTaskHandler class AMC23TaskHandler(MathTaskHandler): - def __init__(self): - self.dataset = "AI-MO/aimo-validation-amc" - - @staticmethod - def get_question_key(): - return "problem" - def load_and_filter_dataset( self, start, end, split="train", source=None, filter_difficulty=False, args=None ): - dataset = load_dataset(self.dataset) - train_data = dataset[split].to_pandas() + train_data = self.load_dataset(source=source, split=split) filtered_data = train_data[train_data["url"].str.contains("2023", na=False)] return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] diff --git a/skythought/tools/tasks/apps/apps.yaml b/skythought/tools/tasks/apps/apps.yaml index e69de29..a4f27f7 100644 --- a/skythought/tools/tasks/apps/apps.yaml +++ b/skythought/tools/tasks/apps/apps.yaml @@ -0,0 +1,10 @@ +dataset_name: codeparrot/apps +dataset_kwargs: + trust_remote_code: true +dataset_split: train +question_key: question +difficulty: null +templating_parameters: + with_fn_name_instruction: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition." + without_fn_name_instruction: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution." 
+ with_starter_code_template: "{input}\n{starter_code}" \ No newline at end of file diff --git a/skythought/tools/tasks/apps/apps_handler.py b/skythought/tools/tasks/apps/apps_handler.py index d32e69e..5f33647 100644 --- a/skythought/tools/tasks/apps/apps_handler.py +++ b/skythought/tools/tasks/apps/apps_handler.py @@ -4,7 +4,6 @@ from multiprocessing import Manager import numpy as np -from datasets import load_dataset from tasks.apps.apps_util import run_test as apps_run_test from util.common import has_code @@ -13,27 +12,20 @@ class APPSTaskHandler(TaskHandler): - @staticmethod - def get_question_key(): - return "question" - def generate_prompt(self, test_case, prompt, starter_code=None): - _input = "" - data = test_case - if not data.get("fn_name"): - _input += "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition." # "\nUse Standard Input format"#\n" + if not test_case.get("fn_name"): + _input = self.task_config.templating_parameters[ + "with_fn_name_instruction" + ] # "\nUse Standard Input format"#\n" else: - _input += "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution." # "\nUse Call-Based format"#\n" - data = prompt - _input += data + _input = self.task_config.templating_parameters[ + "without_fn_name_instruction" + ] # "\nUse Call-Based format"#\n" + _input += prompt if starter_code is not None: - data = starter_code - data = "\n" + data # + "\n" - _input += data - else: - # _input += "\n\n" - pass - + _input = self.task_config.templating_parameters[ + "with_starter_code_template" + ].format(input=_input, starter_code=starter_code) return _input def check_correctness(self, problem, generation): @@ -107,21 +99,20 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=False, args=None ): - dataset = load_dataset("codeparrot/apps", trust_remote_code=True) - train_data = dataset[split].to_pandas() - if not filter_difficulty: - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - return ( - train_data.query("difficulty == @source").iloc[start:end] - if end > 0 - else train_data.query("difficulty == @source").iloc[start:] - ) + train_data = self.load_dataset(source=source, split=split) + if filter_difficulty or self.task_config.difficulty: + difficulty = ( + self.task_config.difficulty if not filter_difficulty else source + ) + train_data = train_data.filter(lambda x: x["difficulty"] == difficulty) + + return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): return [ row.to_dict() for _, row in train_data.iterrows() - if str(row["question"]) not in results + if str(row[self.question_key]) not in results ] diff --git a/skythought/tools/tasks/arc/arc_c.yaml b/skythought/tools/tasks/arc/arc_c.yaml new file mode 100644 index 0000000..2186578 --- /dev/null +++ b/skythought/tools/tasks/arc/arc_c.yaml @@ -0,0 +1,7 @@ +dataset_name: allenai/ai2_arc +dataset_source: ARC-Challenge +dataset_split: train +question_key: question +templating_parameters: + # We combine choices for a question into choices_text entry in the dataset + template: "Given the following question and four candidate 
answers (A, B, C and D), choose the best answer. Your response should end with \"The best answer is [the_answer_letter]\" where [the_answer_letter] is one of the four letter choice (A, B, C, or D).\n{question}\n{choices_text}" \ No newline at end of file diff --git a/skythought/tools/tasks/arc/arc_handler.py b/skythought/tools/tasks/arc/arc_handler.py index 6caa056..3947d25 100644 --- a/skythought/tools/tasks/arc/arc_handler.py +++ b/skythought/tools/tasks/arc/arc_handler.py @@ -1,37 +1,29 @@ import re from typing import Any, Dict -from datasets import load_dataset - from tasks.common import TaskHandler from util.math_parsing_util import extract_answer class ARCChallengeTaskHandler(TaskHandler): - def __init__(self) -> None: - super().__init__() - self.dataset = "allenai/ai2_arc" + def __init__(self, yaml_file_path) -> None: + super().__init__(yaml_file_path) self.ans_re = re.compile(r"[Tt]he best answer is ([A-D])[\.\,]*", re.IGNORECASE) self.letter_re = re.compile(r"([A-D])[\.\,]*") self.canonical_options = ["A", "B", "C", "D"] self.invalid_ans = "[invalid]" - @staticmethod - def get_question_key(): - return "question" - def generate_prompt(self, problem): - question = problem["question"] choices = problem["choices"] choices_text = "\n".join( [ f"{label}.{choice}" - for label, choice in zip(["A", "B", "C", "D"], choices["text"]) + for label, choice in zip(self.canonical_options, choices["text"]) ] ) - full_prompt = ( - 'Given the following question and four candidate answers (A, B, C and D), choose the best answer. Your response should end with "The best answer is [the_answer_letter]" where [the_answer_letter] is one of the four letter choice (A, B, C, or D).\n' - + f"{question}\n{choices_text}" + problem["choices_text"] = choices_text + full_prompt = self.task_config.templating_parameters["template"].format( + **problem ) return full_prompt @@ -74,17 +66,16 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=False, args=None ): - dataset = load_dataset(self.dataset, "ARC-Challenge") - train_data = dataset[split].to_pandas() + train_data = self.load_dataset(source=source, split=split) return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): return [ row.to_dict() for _, row in train_data.iterrows() - if str(row["question"]) not in results + if str(row[self.question_key]) not in results ] def get_answer(self, completion): diff --git a/skythought/tools/tasks/common.py b/skythought/tools/tasks/common.py index 40306eb..9800e89 100644 --- a/skythought/tools/tasks/common.py +++ b/skythought/tools/tasks/common.py @@ -2,32 +2,39 @@ import os from typing import Any, Dict, List, Optional +import pandas as pd import yaml +from datasets import load_dataset from pydantic import BaseModel, Field class TaskConfig(BaseModel): dataset_name: str dataset_source: Optional[str] = None + dataset_split: str + dataset_kwargs: Optional[Dict[str, Any]] = None question_key: str templating_parameters: Dict[str, str] = Field(default_factory=dict) + # Optional, unused for now fewshot_config: List[Dict[str, Any]] = Field(default_factory=list) num_fewshot: int = 0 class TaskHandler: + task_config_cls = TaskConfig + def __init__(self, yaml_file_path): self.yaml_file_path = yaml_file_path - self.task_config = 
TaskConfig(**self.load_yaml(yaml_file_path)) + self.task_config = self.task_config_cls(**self.load_yaml(yaml_file_path)) @staticmethod def load_yaml(yaml_file_path): with open(yaml_file_path, "r", encoding="utf-8") as f: return yaml.safe_load(f) - @staticmethod - def get_question_key(): - raise NotImplementedError("Subclasses should implement this method.") + @property + def question_key(self): + return self.task_config.question_key def check_correctness(self, problem, generation): raise NotImplementedError("Subclasses should implement this method.") @@ -45,6 +52,16 @@ def load_existing_results(self, result_file): records = json.load(f) return records + def load_dataset(self, source=None, split=None, **kwargs) -> pd.DataFrame: + dataset = load_dataset( + self.task_config.dataset_name, + source if source else self.task_config.dataset_source, + split=split if split else self.task_config.dataset_split, + **self.task_config.dataset_kwargs + ) + data = dataset.to_pandas() + return data + def load_and_filter_dataset( self, start, end, split="train", source=None, filter_difficulty=False, args=None ): diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml new file mode 100644 index 0000000..4f577ab --- /dev/null +++ b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml @@ -0,0 +1,7 @@ +dataset_name: Idavidrein/gpqa +dataset_source: gpqa_diamond +dataset_split: train +question_key: Question +templating_parameters: + # For GPQA, we combine the Question key and the multiple choice answers into a single `prompt` entry + template: "Return your final response within \\boxed{{}} and only include the letter choice (A, B, C, or D) as your final response. {prompt}" \ No newline at end of file diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py index a4904fb..fb3ad30 100644 --- a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py +++ b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py @@ -1,28 +1,15 @@ -import copy -import json -import multiprocessing -import os import random -import re -import numpy as np -from datasets import load_dataset -from typing import Dict, Any -from multiprocessing import Manager -from util.model_utils import SYSTEM_PROMPT, MODEL_TO_NAME + from tasks.common import TaskHandler -from util.math_parsing_util import get_multiple_choice_answer, extract_answer, math_equal, mmlu_pro_extract_answer +from util.math_parsing_util import get_multiple_choice_answer + class GPQADiamondTaskHandler(TaskHandler): def __init__(self): self.dataset = "Idavidrein/gpqa" - @staticmethod - def generate_prompt(prompt): - return "Return your final response within \\boxed{{}} and only include the letter choice (A, B, C, or D) as your final response. " + prompt - - @staticmethod - def get_question_key(): - return "Question" + def generate_prompt(self, problem): + return self.task_config.templating_parameters["template"].format(**problem) def update_results(self, problem, response): if not isinstance(response, str): @@ -40,51 +27,70 @@ def update_results(self, problem, response): else: response_entry["correctness"] = False response_entry["reason"] = "Solution is incorrect." 
- + return response_entry - + def check_correctness(self, problem, generation): pred = get_multiple_choice_answer(generation) answer = problem["Answer"] return answer == pred - + def get_multiple_choice_answers(self, data): answers = [ data["Correct Answer"], data["Incorrect Answer 1"], data["Incorrect Answer 2"], - data["Incorrect Answer 3"] + data["Incorrect Answer 3"], ] random.shuffle(answers) # Map options to letters options = ["A", "B", "C", "D"] - options_to_answers = {letter: answer for letter, answer in zip(options, answers)} + options_to_answers = { + letter: answer for letter, answer in zip(options, answers) + } # Format the options into the string - multiple_choice_string = ", ".join(f"{letter}) {options_to_answers[letter]}" for letter in options) + multiple_choice_string = ", ".join( + f"{letter}) {options_to_answers[letter]}" for letter in options + ) # Save the letter corresponding to the correct answer - correct_answer_letter = next(letter for letter, answer in options_to_answers.items() if answer == data["Correct Answer"]) + correct_answer_letter = next( + letter + for letter, answer in options_to_answers.items() + if answer == data["Correct Answer"] + ) return multiple_choice_string, correct_answer_letter - + def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: - multiple_choice_string, correct_answer_letter = self.get_multiple_choice_answers(problem) + ( + multiple_choice_string, + correct_answer_letter, + ) = self.get_multiple_choice_answers(problem) problem["Answer"] = correct_answer_letter - prompt_text = self.generate_prompt(problem["Question"] + "\n" + multiple_choice_string) - conversations.append([ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt_text} - ]) + problem["prompt"] = problem["Question"] + "\n" + multiple_choice_string + prompt_text = self.generate_prompt(problem) + conversations.append( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt_text}, + ] + ) return conversations - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): - dataset = load_dataset(self.dataset, "gpqa_diamond") - train_data = dataset[split].to_pandas() + def load_and_filter_dataset( + self, start, end, split=None, source=None, filter_difficulty=False, args=None + ): + train_data = self.load_dataset(source=source, split=split) return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): - return [row.to_dict() for _, row in train_data.iterrows() if str(row["Question"]) not in results] \ No newline at end of file + return [ + row.to_dict() + for _, row in train_data.iterrows() + if str(row["Question"]) not in results + ] diff --git a/skythought/tools/tasks/gsm8k/gsm8k.yaml b/skythought/tools/tasks/gsm8k/gsm8k.yaml new file mode 100644 index 0000000..bcd5920 --- /dev/null +++ b/skythought/tools/tasks/gsm8k/gsm8k.yaml @@ -0,0 +1,7 @@ +dataset_name: "openai/gsm8k" +dataset_source: main +dataset_split: test +question_key: question +templating_parameters: + template: "Given the following problem, reason and give a final answer to the problem.\nProblem: {question}\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem." 
+ diff --git a/skythought/tools/tasks/gsm8k/gsm8k_handler.py b/skythought/tools/tasks/gsm8k/gsm8k_handler.py index c10d144..a954d71 100644 --- a/skythought/tools/tasks/gsm8k/gsm8k_handler.py +++ b/skythought/tools/tasks/gsm8k/gsm8k_handler.py @@ -1,8 +1,6 @@ import re from typing import Any, Dict -from datasets import load_dataset - from tasks.common import TaskHandler from util.math_parsing_util import extract_answer @@ -15,14 +13,8 @@ def __init__(self) -> None: self.gt_re = re.compile(r"#### (\-?[0-9\.\,]+)") self.invalid_ans = "[invalid]" - @staticmethod - def get_question_key(): - return "question" - def generate_prompt(self, problem): - question = problem["question"] - full_prompt = f'Given the following problem, reason and give a final answer to the problem.\nProblem: {question}\nYour response should end with "The final answer is [answer]" where [answer] is the response to the problem.' - return full_prompt + return self.task_config.templating_parameters["template"].format(**problem) def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: gt_answer = self.extract_gt_answer(problem["answer"]) @@ -62,10 +54,9 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=False, args=None ): - dataset = load_dataset(self.dataset, "main") - train_data = dataset[split].to_pandas() + train_data = self.load_dataset(source=source, split=split) return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): diff --git a/skythought/tools/tasks/livecodebench/livecodebench.yaml b/skythought/tools/tasks/livecodebench/livecodebench.yaml index e69de29..910debb 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench.yaml +++ b/skythought/tools/tasks/livecodebench/livecodebench.yaml @@ -0,0 +1,11 @@ +dataset_name: "livecodebench/code_generation_lite" # repo ID in huggingface +dataset_source: null +dataset_split: test +dataset_kwargs: + version_tag: release_v2 + trust_remote_code: true +question_key: task_id +templating_parameters: + stdin_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}" + non_stdin_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. 
{prompt}" +difficulty: null # use all by default diff --git a/skythought/tools/tasks/livecodebench/livecodebench_handler.py b/skythought/tools/tasks/livecodebench/livecodebench_handler.py index 317a0f8..cebed95 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench_handler.py +++ b/skythought/tools/tasks/livecodebench/livecodebench_handler.py @@ -1,9 +1,7 @@ import copy from typing import Dict -from datasets import load_dataset - -from tasks.common import TaskHandler +from tasks.common import TaskConfig, TaskHandler from tasks.livecodebench.livecodebench_util import ( map_to_example, post_process_code, @@ -13,25 +11,23 @@ from util.common import has_code +class LiveCodeBenchTaskConfig(TaskConfig): + difficulty: str = None # use all by default + + class LiveCodeBenchTaskHandler(TaskHandler): + task_config_cls = LiveCodeBenchTaskConfig + def generate_prompt(self, problem): - # print(problem) - prompt = problem["prompt"] if problem["is_stdin"]: - return ( - "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition." - + prompt + return self.task_config.templating_parameters["stdin_template"].format( + **problem ) else: - return ( - "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution." - + prompt + return self.task_config.templating_parameters["non_stdin_template"].format( + **problem ) - @staticmethod - def get_question_key(): - return "task_id" - def check_correctness( self, problem: Dict, @@ -105,26 +101,26 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split="test", source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=False, args=None ): - dataset = load_dataset( - "livecodebench/code_generation_lite", - version_tag="release_v2", - split=split, - trust_remote_code=True, - ) - if filter_difficulty: - dataset = dataset.filter(lambda example: example["difficulty"] == source) + dataset = self.load_dataset(source=source, split=split) + # Filter by CLI or config + if filter_difficulty or self.task_config.difficulty: + difficulty = source if filter_difficulty else self.task_config.difficulty + dataset = dataset.filter( + lambda example: example["difficulty"] == difficulty + ) dataset = dataset.map( lambda example: { "private_test_cases": translate_private_test_cases( example["private_test_cases"] ) - } + }, + writer_batch_size=100, ) # Apply the mapping function dataset = dataset.map( - map_to_example, remove_columns=dataset.column_names + map_to_example, remove_columns=dataset.column_names, writer_batch_size=100 ).to_pandas() return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] @@ -132,5 +128,5 @@ def process_remaining_data(self, train_data, results): return [ row.to_dict() for _, row in train_data.iterrows() - if str(row["task_id"]) not in results + if str(row[self.question_key]) not in results ] diff --git a/skythought/tools/tasks/math/math500.yaml b/skythought/tools/tasks/math/math500.yaml index ba45bf2..24c8927 100644 --- a/skythought/tools/tasks/math/math500.yaml +++ b/skythought/tools/tasks/math/math500.yaml @@ -1,9 +1,9 @@ dataset_name: "qq8933/MATH500" # repo ID in huggingface dataset_source: null # which subset on huggingface question_key: problem -split: test +dataset_split: test templating_parameters: - - 
instruction: "Return your final response within \\boxed{{}}. " + template: "Return your final response within \\boxed{{}}. {problem}" # optional. Not supported yet. fewshot_config: - question: ... diff --git a/skythought/tools/tasks/math/math_handler.py b/skythought/tools/tasks/math/math_handler.py index 791b908..3d5e226 100644 --- a/skythought/tools/tasks/math/math_handler.py +++ b/skythought/tools/tasks/math/math_handler.py @@ -1,12 +1,10 @@ -from datasets import load_dataset - from tasks.common import TaskHandler from util.math_parsing_util import extract_answer, math_equal, strip_answer_string class MathTaskHandler(TaskHandler): - def generate_prompt(self, prompt): - return self.task_config.templating_parameters["instruction"] + prompt + def generate_prompt(self, problem): + return self.task_config.templating_parameters["template"].format(**problem) def check_correctness(self, problem, generation): answer = strip_answer_string(problem["answer"]) @@ -36,7 +34,7 @@ def update_results(self, problem, response): def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: - prompt_text = self.generate_prompt(problem[self.task_config.question_key]) + prompt_text = self.generate_prompt(problem) conversations.append( [ {"role": "system", "content": system_prompt}, @@ -49,12 +47,11 @@ def process_remaining_data(self, train_data, results): return [ row.to_dict() for _, row in train_data.iterrows() - if str(row[self.task_config.question_key]) not in results + if str(row[self.question_key]) not in results ] def load_and_filter_dataset( - self, start, end, split="test", source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=False, args=None ): - dataset = load_dataset(self.dataset) - train_data = dataset[split].to_pandas() - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] + dataset = self.load_dataset(source=source, split=split) + return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] diff --git a/skythought/tools/tasks/mmlu/mmlu.yaml b/skythought/tools/tasks/mmlu/mmlu.yaml new file mode 100644 index 0000000..bae4caf --- /dev/null +++ b/skythought/tools/tasks/mmlu/mmlu.yaml @@ -0,0 +1,5 @@ +dataset_name: "cais/mmlu" +dataset_source: default +dataset_split: test +templating_parameters: + template: "Return your final response within \\boxed{{}}. {prompt}" diff --git a/skythought/tools/tasks/mmlu/mmlu_handler.py b/skythought/tools/tasks/mmlu/mmlu_handler.py index 1f7cd53..0739122 100644 --- a/skythought/tools/tasks/mmlu/mmlu_handler.py +++ b/skythought/tools/tasks/mmlu/mmlu_handler.py @@ -6,15 +6,8 @@ class MMLUTaskHandler(TaskHandler): - def __init__(self): - self.dataset = "cais/mmlu" - def generate_prompt(self, prompt): - return "Return your final response within \\boxed{{}}. " + prompt - - @staticmethod - def get_question_key(): - return "question" + return self.task_config.templating_parameters["template"].format(prompt=prompt) def check_correctness(self, problem, generation): pred = get_multiple_choice_answer(generation) @@ -78,9 +71,8 @@ def load_and_filter_dataset( class MMLUProTaskHandler(MMLUTaskHandler): - def __init__(self): - super().__init__() - self.dataset = "TIGER-Lab/MMLU-Pro" + def __init__(self, yaml_file_path): + super().__init__(yaml_file_path) self.choices = [ "A", "B", @@ -100,13 +92,8 @@ def __init__(self): "P", ] - @staticmethod - def generate_prompt(prompt): - return "Return your final response within \\boxed{{}}. 
" + prompt - - @staticmethod - def get_question_key(): - return "question" + def generate_prompt(self, prompt): + return self.task_config.templating_parameters["template"].format(prompt=prompt) def check_correctness(self, problem, generation): pred = mmlu_pro_extract_answer(generation) @@ -123,6 +110,5 @@ def get_multiple_choice_answers(self, problem): def load_and_filter_dataset( self, start, end, split="test", source=None, filter_difficulty=False, args=None ): - dataset = load_dataset(self.dataset, "default") - train_data = dataset[split].to_pandas() - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] + dataset = self.load_dataset(source=source, split=split) + return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] diff --git a/skythought/tools/tasks/mmlu/mmlu_pro.yaml b/skythought/tools/tasks/mmlu/mmlu_pro.yaml new file mode 100644 index 0000000..7305cf3 --- /dev/null +++ b/skythought/tools/tasks/mmlu/mmlu_pro.yaml @@ -0,0 +1,5 @@ +dataset_name: TIGER-Lab/MMLU-Pro +dataset_source: default +dataset_split: test +templating_parameters: + template: "Return your final response within \\boxed{{}}. {prompt}" diff --git a/skythought/tools/tasks/numina/numina.yaml b/skythought/tools/tasks/numina/numina.yaml new file mode 100644 index 0000000..f9849b0 --- /dev/null +++ b/skythought/tools/tasks/numina/numina.yaml @@ -0,0 +1,6 @@ +dataset_name: "AI-MO/NuminaMath-CoT" +dataset_source: default +dataset_split: train +question_key: problem +templating_parameters: + template: "Return your final response within \\boxed{{}}. {prompt}" diff --git a/skythought/tools/tasks/numina/numina_handler.py b/skythought/tools/tasks/numina/numina_handler.py index d1093c5..2ddfa67 100644 --- a/skythought/tools/tasks/numina/numina_handler.py +++ b/skythought/tools/tasks/numina/numina_handler.py @@ -1,17 +1,19 @@ from datasets import load_dataset -from tasks.common import TaskHandler +from tasks.common import TaskConfig, TaskHandler from util.common import TimeoutException, timeout from util.math_parsing_util import extract_answer, math_equal, strip_answer_string +class NUMINATaskConfig(TaskConfig): + difficulty: str = None # use all by default + + class NUMINATaskHandler(TaskHandler): - @staticmethod - def get_question_key(): - return "problem" + task_config_cls = NUMINATaskConfig def generate_prompt(self, prompt): - return "Return your final response within \\boxed{{}}. 
" + prompt + return self.task_config.templating_parameters["template"].format(prompt=prompt) @timeout(5) # Add timeout of 5 seconds def check_correctness(self, problem, generation): @@ -73,25 +75,20 @@ def make_conversations(self, data, system_prompt, model=None): def load_and_filter_dataset( self, start, end, split="train", source=None, filter_difficulty=False, args=None ): - dataset = load_dataset("AI-MO/NuminaMath-CoT") - train_data = dataset[split].to_pandas() - train_data = ( - train_data.query("source == @source").iloc[start:end] - if end > 0 - else train_data.query("source == @source").iloc[start:] - ) - train_data = train_data[train_data["solution"].str.contains("boxed", na=False)] + dataset = self.load_dataset(source=source, split=split) + dataset = dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] + dataset = dataset[dataset["solution"].str.contains("boxed", na=False)] if filter_difficulty: diff_dict = self.get_difficulty_dict(source, start, end) - train_data = train_data[ - train_data["problem"] + dataset = dataset[ + dataset["problem"] .map(diff_dict) .apply( lambda x: x >= args.math_difficulty_lower_bound and x <= args.math_difficulty_upper_bound ) ] - return train_data + return dataset def process_remaining_data(self, train_data, results): return [ diff --git a/skythought/tools/tasks/taco/taco.yaml b/skythought/tools/tasks/taco/taco.yaml index e69de29..4a7b6a1 100644 --- a/skythought/tools/tasks/taco/taco.yaml +++ b/skythought/tools/tasks/taco/taco.yaml @@ -0,0 +1,14 @@ +dataset_name: "BAAI/TACO" +dataset_source: default +dataset_split: ALL +dataset_kwargs: + trust_remote_code: true +templating_parameters: + initial_template: "\nQUESTION:\n{prompt}" + # Add starter code to initial template + starter_code_template: "{input}\n{starter_code}" + # stdin template is used when there is no starter code or fn_name + stdin_template: "{input}\nUse Standard Input format\nANSWER:\n" + # call template is used when there is starter code or fn_name + call_template: "{input}\nUse Call-Based format\nANSWER:\n" + diff --git a/skythought/tools/tasks/taco/taco_handler.py b/skythought/tools/tasks/taco/taco_handler.py index ccb1692..2b62a5b 100644 --- a/skythought/tools/tasks/taco/taco_handler.py +++ b/skythought/tools/tasks/taco/taco_handler.py @@ -3,31 +3,41 @@ from multiprocessing import Manager import numpy as np -from datasets import load_dataset from tasks.taco.taco_util import run_test as taco_run_test from util.common import has_code -from ..common import TaskHandler +from ..common import TaskConfig, TaskHandler + + +class TACOTaskConfig(TaskConfig): + difficulty: str = None # use all by default class TACOTaskHandler(TaskHandler): - @staticmethod - def get_question_key(): - return "question" + task_config_cls = TACOTaskConfig def generate_prompt(self, prompt, starter_code=None, fn_name=None): - _input = "\nQUESTION:\n" - _input += prompt + _input = self.task_config.templating_parameters["initial_template"].format( + prompt=prompt + ) + if starter_code: - _input += starter_code + _input = self.task_config.templating_parameters[ + "starter_code_template" + ].format(input=_input, starter_code=starter_code) + else: + _input = self.task_config.templating_parameters["initial_template"].format( + prompt=prompt + ) if (not fn_name) and (not starter_code): - call_format = "\nUse Standard Input format" - _input += call_format + _input = self.task_config.templating_parameters["stdin_template"].format( + input=_input + ) else: - call_format = "\nUse Call-Based format" - _input += 
call_format - _input += "\nANSWER:\n" + _input = self.task_config.templating_parameters["call_template"].format( + input=_input + ) return _input @@ -105,15 +115,14 @@ def make_conversations(self, data, system_prompt, model=None): def load_and_filter_dataset( self, start, end, split="train", source=None, filter_difficulty=False, args=None ): - dataset = load_dataset("BAAI/TACO", "ALL", trust_remote_code=True) - train_data = dataset[split].to_pandas() - if not filter_difficulty: - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] - return ( - train_data.query("difficulty == @source").iloc[start:end] - if end > 0 - else train_data.query("difficulty == @source").iloc[start:] - ) + dataset = self.load_dataset(source=source, split=split) + if filter_difficulty or self.task_config.difficulty: + difficulty = source if filter_difficulty else self.task_config.difficulty + dataset = dataset.filter( + lambda example: example["difficulty"] == difficulty + ) + + return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] def process_remaining_data(self, train_data, results): return [ From b7d95322c4bceccbeec28fa7c2a36f3443e529b1 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 18:02:32 +0000 Subject: [PATCH 10/47] x Signed-off-by: SumanthRH --- skythought/tools/tasks/aime/aime_handler.py | 2 +- skythought/tools/tasks/amc23/amc23_handler.py | 4 ++-- skythought/tools/tasks/apps/apps.yaml | 5 +++-- skythought/tools/tasks/apps/apps_handler.py | 15 +++++++++------ skythought/tools/tasks/arc/arc_handler.py | 2 +- skythought/tools/tasks/common.py | 11 +++++------ .../tasks/gpqa_diamond/gpqa_diamond_handler.py | 2 +- skythought/tools/tasks/gsm8k/gsm8k_handler.py | 2 +- .../tasks/livecodebench/livecodebench_handler.py | 4 ++-- skythought/tools/tasks/numina/numina_handler.py | 6 ++++-- skythought/tools/tasks/taco/taco_handler.py | 5 +++-- 11 files changed, 32 insertions(+), 26 deletions(-) diff --git a/skythought/tools/tasks/aime/aime_handler.py b/skythought/tools/tasks/aime/aime_handler.py index 9a488c1..5935ef5 100644 --- a/skythought/tools/tasks/aime/aime_handler.py +++ b/skythought/tools/tasks/aime/aime_handler.py @@ -31,6 +31,6 @@ def make_conversations(self, data, system_prompt, model=None): def load_and_filter_dataset( self, start, end, split=None, source=None, filter_difficulty=False, args=None ): - train_data = self.load_dataset(source=source, split=split) + train_data = self.load_dataset(source=source, split=split).to_pandas() filtered_data = train_data[train_data["url"].str.contains("2024", na=False)] return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] diff --git a/skythought/tools/tasks/amc23/amc23_handler.py b/skythought/tools/tasks/amc23/amc23_handler.py index 7087089..5b765cc 100644 --- a/skythought/tools/tasks/amc23/amc23_handler.py +++ b/skythought/tools/tasks/amc23/amc23_handler.py @@ -3,8 +3,8 @@ class AMC23TaskHandler(MathTaskHandler): def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=False, args=None ): - train_data = self.load_dataset(source=source, split=split) + train_data = self.load_dataset(source=source, split=split).to_pandas() filtered_data = train_data[train_data["url"].str.contains("2023", na=False)] return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] diff --git a/skythought/tools/tasks/apps/apps.yaml b/skythought/tools/tasks/apps/apps.yaml index 
a4f27f7..4846504 100644 --- a/skythought/tools/tasks/apps/apps.yaml +++ b/skythought/tools/tasks/apps/apps.yaml @@ -5,6 +5,7 @@ dataset_split: train question_key: question difficulty: null templating_parameters: - with_fn_name_instruction: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition." - without_fn_name_instruction: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution." + with_fn_name_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}" + without_fn_name_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. {prompt}" + # Add starter code on top of the initial template with_starter_code_template: "{input}\n{starter_code}" \ No newline at end of file diff --git a/skythought/tools/tasks/apps/apps_handler.py b/skythought/tools/tasks/apps/apps_handler.py index 5f33647..c06ac45 100644 --- a/skythought/tools/tasks/apps/apps_handler.py +++ b/skythought/tools/tasks/apps/apps_handler.py @@ -15,13 +15,16 @@ class APPSTaskHandler(TaskHandler): def generate_prompt(self, test_case, prompt, starter_code=None): if not test_case.get("fn_name"): _input = self.task_config.templating_parameters[ - "with_fn_name_instruction" - ] # "\nUse Standard Input format"#\n" + "with_fn_name_template" + ].format( + prompt=prompt + ) # "\nUse Standard Input format"#\n" else: _input = self.task_config.templating_parameters[ - "without_fn_name_instruction" - ] # "\nUse Call-Based format"#\n" - _input += prompt + "without_fn_name_template" + ].format( + prompt=prompt + ) # "\nUse Call-Based format"#\n" if starter_code is not None: _input = self.task_config.templating_parameters[ "with_starter_code_template" @@ -101,7 +104,7 @@ def make_conversations(self, data, system_prompt, model=None): def load_and_filter_dataset( self, start, end, split=None, source=None, filter_difficulty=False, args=None ): - train_data = self.load_dataset(source=source, split=split) + train_data = self.load_dataset(source=source, split=split).to_pandas() if filter_difficulty or self.task_config.difficulty: difficulty = ( self.task_config.difficulty if not filter_difficulty else source diff --git a/skythought/tools/tasks/arc/arc_handler.py b/skythought/tools/tasks/arc/arc_handler.py index 3947d25..eb2dc87 100644 --- a/skythought/tools/tasks/arc/arc_handler.py +++ b/skythought/tools/tasks/arc/arc_handler.py @@ -68,7 +68,7 @@ def make_conversations(self, data, system_prompt, model=None): def load_and_filter_dataset( self, start, end, split=None, source=None, filter_difficulty=False, args=None ): - train_data = self.load_dataset(source=source, split=split) + train_data = self.load_dataset(source=source, split=split).to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): diff --git a/skythought/tools/tasks/common.py b/skythought/tools/tasks/common.py index 9800e89..47ef30d 100644 --- a/skythought/tools/tasks/common.py +++ b/skythought/tools/tasks/common.py @@ -2,8 +2,8 @@ import os from typing import Any, Dict, List, Optional -import pandas as pd import yaml +from datasets import Dataset as HFDataset from 
datasets import load_dataset from pydantic import BaseModel, Field @@ -52,15 +52,14 @@ def load_existing_results(self, result_file): records = json.load(f) return records - def load_dataset(self, source=None, split=None, **kwargs) -> pd.DataFrame: + def load_dataset(self, source=None, split=None, **kwargs) -> HFDataset: dataset = load_dataset( - self.task_config.dataset_name, - source if source else self.task_config.dataset_source, + path=self.task_config.dataset_name, + name=source if source else self.task_config.dataset_source, split=split if split else self.task_config.dataset_split, **self.task_config.dataset_kwargs ) - data = dataset.to_pandas() - return data + return dataset def load_and_filter_dataset( self, start, end, split="train", source=None, filter_difficulty=False, args=None diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py index fb3ad30..50c5092 100644 --- a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py +++ b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py @@ -85,7 +85,7 @@ def make_conversations(self, data, system_prompt, model=None): def load_and_filter_dataset( self, start, end, split=None, source=None, filter_difficulty=False, args=None ): - train_data = self.load_dataset(source=source, split=split) + train_data = self.load_dataset(source=source, split=split).to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): diff --git a/skythought/tools/tasks/gsm8k/gsm8k_handler.py b/skythought/tools/tasks/gsm8k/gsm8k_handler.py index a954d71..4ab52ff 100644 --- a/skythought/tools/tasks/gsm8k/gsm8k_handler.py +++ b/skythought/tools/tasks/gsm8k/gsm8k_handler.py @@ -56,7 +56,7 @@ def make_conversations(self, data, system_prompt, model=None): def load_and_filter_dataset( self, start, end, split=None, source=None, filter_difficulty=False, args=None ): - train_data = self.load_dataset(source=source, split=split) + train_data = self.load_dataset(source=source, split=split).to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): diff --git a/skythought/tools/tasks/livecodebench/livecodebench_handler.py b/skythought/tools/tasks/livecodebench/livecodebench_handler.py index cebed95..56d2aea 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench_handler.py +++ b/skythought/tools/tasks/livecodebench/livecodebench_handler.py @@ -1,5 +1,5 @@ import copy -from typing import Dict +from typing import Dict, Optional from tasks.common import TaskConfig, TaskHandler from tasks.livecodebench.livecodebench_util import ( @@ -12,7 +12,7 @@ class LiveCodeBenchTaskConfig(TaskConfig): - difficulty: str = None # use all by default + difficulty: Optional[str] = None # use all by default class LiveCodeBenchTaskHandler(TaskHandler): diff --git a/skythought/tools/tasks/numina/numina_handler.py b/skythought/tools/tasks/numina/numina_handler.py index 2ddfa67..b941220 100644 --- a/skythought/tools/tasks/numina/numina_handler.py +++ b/skythought/tools/tasks/numina/numina_handler.py @@ -1,3 +1,5 @@ +from typing import Optional + from datasets import load_dataset from tasks.common import TaskConfig, TaskHandler @@ -6,7 +8,7 @@ class NUMINATaskConfig(TaskConfig): - difficulty: str = None # use all by default + difficulty: Optional[str] = None # use all by default class NUMINATaskHandler(TaskHandler): @@ -75,7 +77,7 @@ def 
make_conversations(self, data, system_prompt, model=None): def load_and_filter_dataset( self, start, end, split="train", source=None, filter_difficulty=False, args=None ): - dataset = self.load_dataset(source=source, split=split) + dataset = self.load_dataset(source=source, split=split).to_pandas() dataset = dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] dataset = dataset[dataset["solution"].str.contains("boxed", na=False)] if filter_difficulty: diff --git a/skythought/tools/tasks/taco/taco_handler.py b/skythought/tools/tasks/taco/taco_handler.py index 2b62a5b..fc014ee 100644 --- a/skythought/tools/tasks/taco/taco_handler.py +++ b/skythought/tools/tasks/taco/taco_handler.py @@ -1,6 +1,7 @@ import json import multiprocessing from multiprocessing import Manager +from typing import Optional import numpy as np @@ -11,7 +12,7 @@ class TACOTaskConfig(TaskConfig): - difficulty: str = None # use all by default + difficulty: Optional[str] = None # use all by default class TACOTaskHandler(TaskHandler): @@ -115,7 +116,7 @@ def make_conversations(self, data, system_prompt, model=None): def load_and_filter_dataset( self, start, end, split="train", source=None, filter_difficulty=False, args=None ): - dataset = self.load_dataset(source=source, split=split) + dataset = self.load_dataset(source=source, split=split).to_pandas() if filter_difficulty or self.task_config.difficulty: difficulty = source if filter_difficulty else self.task_config.difficulty dataset = dataset.filter( From 8b9c67ea2e9f5b63285e51310d595ac7bcb4d546 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 18:20:51 +0000 Subject: [PATCH 11/47] minor linting changes Signed-off-by: SumanthRH --- skythought/tools/.githooks/pre-commit | 6 ++++-- skythought/tools/format.sh | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/skythought/tools/.githooks/pre-commit b/skythought/tools/.githooks/pre-commit index 2cc1e30..927ae65 100755 --- a/skythought/tools/.githooks/pre-commit +++ b/skythought/tools/.githooks/pre-commit @@ -3,6 +3,8 @@ set -e # Get tools directory path relative to git root TOOLS_DIR=$(git rev-parse --show-toplevel)/skythought/tools # Only run pre-commit if changes are in tools/ +# Run pre-commit from tools/ directory to use linting rules in this directory if git diff --cached --name-only | grep "^skythought/tools/"; then - pre-commit run --files $(git diff --cached --name-only | grep "^skythought/tools/") --config $TOOLS_DIR/.pre-commit-config.yaml -fi \ No newline at end of file + cd $TOOLS_DIR; + pre-commit run --files $(git diff --cached --name-only | grep "^skythought/tools/") --config .pre-commit-config.yaml +fi diff --git a/skythought/tools/format.sh b/skythought/tools/format.sh index 75d52f5..2296f90 100644 --- a/skythought/tools/format.sh +++ b/skythought/tools/format.sh @@ -1,3 +1,4 @@ + set -e # Get tools directory path relative to git root @@ -15,4 +16,5 @@ chmod +x $HOOK_SCRIPT git config --local core.hooksPath "$TOOLS_DIR/.githooks" # pre-commit run --all-files always runs from the root directory. we run this only on tools/ for now. 
+cd $TOOLS_DIR; pre-commit run --files $TOOLS_DIR/* --config $TOOLS_DIR/.pre-commit-config.yaml \ No newline at end of file From c767708229e6d067739f9097514fa86cc415601f Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 19:04:33 +0000 Subject: [PATCH 12/47] minor Signed-off-by: SumanthRH --- skythought/tools/tasks/aime/aime_handler.py | 11 ++++------- skythought/tools/tasks/apps/apps_handler.py | 4 ++-- .../tools/tasks/gpqa_diamond/gpqa_diamond_handler.py | 2 -- skythought/tools/tasks/gsm8k/gsm8k_handler.py | 5 ++--- 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/skythought/tools/tasks/aime/aime_handler.py b/skythought/tools/tasks/aime/aime_handler.py index 5935ef5..4b11a19 100644 --- a/skythought/tools/tasks/aime/aime_handler.py +++ b/skythought/tools/tasks/aime/aime_handler.py @@ -3,23 +3,20 @@ class AIMETaskHandler(MathTaskHandler): - def __init__(self): - self.dataset = "AI-MO/aimo-validation-aime" - - def generate_prompt(self, prompt, model): + def generate_prompt(self, problem, model): if MODEL_TO_NAME[model] == "Sky-T1-32B-Preview": return self.task_config.templating_parameters["sky_template"].format( - prompt=prompt + **problem ) else: return self.task_config.templating_parameters["regular_template"].format( - prompt=prompt + **problem ) def make_conversations(self, data, system_prompt, model=None): conversations = [] for problem in data: - prompt_text = self.generate_prompt(problem[self.question_key], model) + prompt_text = self.generate_prompt(problem, model) conversations.append( [ {"role": "system", "content": system_prompt}, diff --git a/skythought/tools/tasks/apps/apps_handler.py b/skythought/tools/tasks/apps/apps_handler.py index c06ac45..b849e25 100644 --- a/skythought/tools/tasks/apps/apps_handler.py +++ b/skythought/tools/tasks/apps/apps_handler.py @@ -18,13 +18,13 @@ def generate_prompt(self, test_case, prompt, starter_code=None): "with_fn_name_template" ].format( prompt=prompt - ) # "\nUse Standard Input format"#\n" + ) else: _input = self.task_config.templating_parameters[ "without_fn_name_template" ].format( prompt=prompt - ) # "\nUse Call-Based format"#\n" + ) if starter_code is not None: _input = self.task_config.templating_parameters[ "with_starter_code_template" diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py index 50c5092..010c1c4 100644 --- a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py +++ b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py @@ -5,8 +5,6 @@ class GPQADiamondTaskHandler(TaskHandler): - def __init__(self): - self.dataset = "Idavidrein/gpqa" def generate_prompt(self, problem): return self.task_config.templating_parameters["template"].format(**problem) diff --git a/skythought/tools/tasks/gsm8k/gsm8k_handler.py b/skythought/tools/tasks/gsm8k/gsm8k_handler.py index 4ab52ff..af811b5 100644 --- a/skythought/tools/tasks/gsm8k/gsm8k_handler.py +++ b/skythought/tools/tasks/gsm8k/gsm8k_handler.py @@ -6,9 +6,8 @@ class GSM8KTaskHandler(TaskHandler): - def __init__(self) -> None: - super().__init__() - self.dataset = "openai/gsm8k" + def __init__(self, yaml_file_path: str) -> None: + super().__init__(yaml_file_path) self.ans_re = re.compile(r"((-?[$0-9.,]{2,})|(-?[0-9]+))") self.gt_re = re.compile(r"#### (\-?[0-9\.\,]+)") self.invalid_ans = "[invalid]" From 995f1a5af95b206b9ea7c118ccbb98549111a264 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 20:04:44 +0000 Subject: [PATCH 13/47] more more more 
Signed-off-by: SumanthRH --- skythought/tools/eval.py | 19 +++----- skythought/tools/inference_and_check.py | 45 +++++++++---------- skythought/tools/tasks/__init__.py | 31 +++++++------ skythought/tools/tasks/aime/aime.yaml | 4 +- skythought/tools/tasks/amc23/amc23.yaml | 3 +- skythought/tools/tasks/apps/apps.yaml | 3 +- skythought/tools/tasks/arc/arc_c.yaml | 3 +- skythought/tools/tasks/arc/arc_handler.py | 6 +-- skythought/tools/tasks/common.py | 33 +++++++++----- .../tasks/gpqa_diamond/gpqa_diamond.yaml | 3 +- skythought/tools/tasks/gsm8k/gsm8k.yaml | 3 +- skythought/tools/tasks/gsm8k/gsm8k_handler.py | 6 +-- .../tasks/livecodebench/livecodebench.yaml | 3 +- .../livecodebench/livecodebench_handler.py | 2 +- skythought/tools/tasks/math/math500.yaml | 3 +- skythought/tools/tasks/mmlu/mmlu.yaml | 3 +- skythought/tools/tasks/mmlu/mmlu_handler.py | 6 +-- skythought/tools/tasks/mmlu/mmlu_pro.yaml | 3 +- skythought/tools/tasks/numina/numina.yaml | 3 +- .../tools/tasks/numina/numina_handler.py | 2 +- skythought/tools/tasks/taco/taco.yaml | 3 +- skythought/tools/tasks/taco/taco_handler.py | 1 - skythought/tools/tasks/task_util.py | 16 +++++++ 23 files changed, 115 insertions(+), 89 deletions(-) create mode 100644 skythought/tools/tasks/task_util.py diff --git a/skythought/tools/eval.py b/skythought/tools/eval.py index b577553..f5fee49 100644 --- a/skythought/tools/eval.py +++ b/skythought/tools/eval.py @@ -1,20 +1,12 @@ import argparse import json import subprocess +import os -# Define eval to split mapping -eval_to_split = { - "MATH500": "test", - "AIME": "train", - "GPQADiamond": "train", - "MMLU": "test", - "MMLUPro": "test", - "LiveCodeBench": "test", - "GSM8K": "test", - "ARC-C": "test", - "AMC23": "train", -} +from skythought.tools.tasks.task_util import get_tasks +module_dir = os.path.dirname(os.path.abspath(__file__)) +TASK_NAMES_TO_YAML = get_tasks(os.path.join(module_dir, "tasks")) def parse_arguments(): parser = argparse.ArgumentParser( @@ -89,6 +81,7 @@ def main(): # Run the Python command for each eval and collect logs for eval_name in evals: + eval_name = eval_name.lower() command = [ "python", script_path, @@ -96,8 +89,6 @@ def main(): model_path, "--dataset", eval_name, - "--split", - eval_to_split[eval_name], "--tp", str(tp), "--temperatures", diff --git a/skythought/tools/inference_and_check.py b/skythought/tools/inference_and_check.py index fdb46d5..5682962 100644 --- a/skythought/tools/inference_and_check.py +++ b/skythought/tools/inference_and_check.py @@ -10,9 +10,12 @@ from tqdm import tqdm from vllm import LLM, SamplingParams -from tasks import TASK_HANDLERS, NUMINATaskHandler, TaskHandler +from tasks import TASK_HANDLER_MAP, NUMINATaskHandler, TaskHandler +from tasks.task_util import get_tasks from util.model_utils import MODEL_TO_NAME, SYSTEM_PROMPT +module_dir = os.path.dirname(os.path.abspath(__file__)) +TASK_NAMES_TO_YAML = get_tasks(os.path.join(module_dir, "tasks")) class NumpyEncoder(json.JSONEncoder): def default(self, obj): @@ -417,26 +420,12 @@ def main(): description="Unified inference and checking for different datasets/tasks." 
) parser.add_argument( - "--dataset", + "--task", type=str, required=True, - choices=[ - "NUMINA", - "APPS", - "TACO", - "MATH500", - "AIME", - "GPQADiamond", - "MMLU", - "MMLUPro", - "LiveCodeBench", - "GSM8K", - "ARC-C", - "AMC23", - ], - help="Dataset to process.", + choices=TASK_NAMES_TO_YAML.keys(), + help="Task to process.", ) - parser.add_argument("--config", type=str, help="Path to the config file.") parser.add_argument( "--model", type=str, @@ -451,7 +440,7 @@ def main(): parser.add_argument( "--split", type=str, - default="train", + default=None, help="Split to use for apps (e.g., train, test).", ) parser.add_argument("--source", type=str, help="Source for the dataset.") @@ -493,7 +482,10 @@ def main(): ) args = parser.parse_args() - handler: TaskHandler = TASK_HANDLERS[args.dataset](args.config) + handler_cls: TaskHandler = TASK_HANDLER_MAP[args.task] + config_path = TASK_NAMES_TO_YAML[args.task] + handler = handler_cls.from_config_path(config_path) + temperatures = [1] if args.model.startswith("openai/o1") else args.temperatures print(f"Temperature: {temperatures}") @@ -502,6 +494,11 @@ def main(): args.n = 1 print("Warning: Temperature 0 does not support multiple samples. Setting n=1.") + # TODO: this can be cleaned up by allowing user override for any task_config with optional task_args + # Currently kept here for consistency with old code + args.split = args.split if args.split else handler.task_config.dataset_split + args.source = args.source if args.source else handler.task_config.dataset_source + # create result dir if not exists if args.result_dir and not os.path.exists(args.result_dir): os.makedirs(args.result_dir) @@ -511,12 +508,12 @@ def main(): ): result_file = os.path.join( args.result_dir, - f"{MODEL_TO_NAME[args.model]}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json", + f"{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json", ) else: result_file = os.path.join( args.result_dir, - f"{MODEL_TO_NAME[args.model]}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}.json", + f"{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}.json", ) if args.check: @@ -525,9 +522,9 @@ def main(): args.math_difficulty_lower_bound is not None or args.math_difficulty_upper_bound is not None ): - converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json" + converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json" else: - converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}.json" + converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}.json" if os.path.exists(converted_file): result_file = converted_file perform_check(handler, temperatures, result_file, args) diff --git a/skythought/tools/tasks/__init__.py b/skythought/tools/tasks/__init__.py index b20b406..80f357a 100644 --- a/skythought/tools/tasks/__init__.py +++ b/skythought/tools/tasks/__init__.py @@ 
-2,7 +2,7 @@ from .amc23.amc23_handler import AMC23TaskHandler from .apps.apps_handler import APPSTaskHandler from .arc.arc_handler import ARCChallengeTaskHandler -from .common import TaskHandler +from .common import TaskHandler, TaskConfig from .gpqa_diamond.gpqa_diamond_handler import GPQADiamondTaskHandler from .gsm8k.gsm8k_handler import GSM8KTaskHandler from .livecodebench.livecodebench_handler import LiveCodeBenchTaskHandler @@ -11,22 +11,21 @@ from .numina.numina_handler import NUMINATaskHandler from .taco.taco_handler import TACOTaskHandler -TASK_HANDLERS = { - "NUMINA": NUMINATaskHandler, - "APPS": APPSTaskHandler, - "TACO": TACOTaskHandler, - "MATH500": MathTaskHandler, - "AIME": AIMETaskHandler, - "GPQADiamond": GPQADiamondTaskHandler, - "MMLU": MMLUTaskHandler, - "MMLUPro": MMLUProTaskHandler, - "LiveCodeBench": LiveCodeBenchTaskHandler, - "GSM8K": GSM8KTaskHandler, - "ARC-C": ARCChallengeTaskHandler, - "AMC23": AMC23TaskHandler, +TASK_HANDLER_MAP = { + "numina": NUMINATaskHandler, + "apps": APPSTaskHandler, + "taco": TACOTaskHandler, + "math500": MathTaskHandler, + "aime": AIMETaskHandler, + "gpqa_diamond": GPQADiamondTaskHandler, + "mmlu": MMLUTaskHandler, + "mmlu_pro": MMLUProTaskHandler, + "livecodebench": LiveCodeBenchTaskHandler, + "gsm8k": GSM8KTaskHandler, + "arc_c": ARCChallengeTaskHandler, + "amc23": AMC23TaskHandler, } - __all__ = [ AIMETaskHandler, APPSTaskHandler, @@ -42,5 +41,5 @@ ARCChallengeTaskHandler, TaskHandler, MathTaskHandler, - TASK_HANDLERS, + TASK_HANDLER_MAP, ] diff --git a/skythought/tools/tasks/aime/aime.yaml b/skythought/tools/tasks/aime/aime.yaml index b2e1fcc..5c2aa83 100644 --- a/skythought/tools/tasks/aime/aime.yaml +++ b/skythought/tools/tasks/aime/aime.yaml @@ -1,5 +1,5 @@ - -dataset_source: AI-MO/aimo-validation-aime +handler: aime +dataset_path: AI-MO/aimo-validation-aime dataset_split: train question_key: problem templating_parameters: diff --git a/skythought/tools/tasks/amc23/amc23.yaml b/skythought/tools/tasks/amc23/amc23.yaml index 6430bf7..627e31d 100644 --- a/skythought/tools/tasks/amc23/amc23.yaml +++ b/skythought/tools/tasks/amc23/amc23.yaml @@ -1,4 +1,5 @@ -dataset_name: AI-MO/aimo-validation-amc +handler: amc23 +dataset_path: AI-MO/aimo-validation-amc dataset_kwargs: trust_remote_code: true dataset_split: train diff --git a/skythought/tools/tasks/apps/apps.yaml b/skythought/tools/tasks/apps/apps.yaml index 4846504..1ee98df 100644 --- a/skythought/tools/tasks/apps/apps.yaml +++ b/skythought/tools/tasks/apps/apps.yaml @@ -1,4 +1,5 @@ -dataset_name: codeparrot/apps +handler: apps +dataset_path: codeparrot/apps dataset_kwargs: trust_remote_code: true dataset_split: train diff --git a/skythought/tools/tasks/arc/arc_c.yaml b/skythought/tools/tasks/arc/arc_c.yaml index 2186578..126e165 100644 --- a/skythought/tools/tasks/arc/arc_c.yaml +++ b/skythought/tools/tasks/arc/arc_c.yaml @@ -1,4 +1,5 @@ -dataset_name: allenai/ai2_arc +handler: arc_c +dataset_path: allenai/ai2_arc dataset_source: ARC-Challenge dataset_split: train question_key: question diff --git a/skythought/tools/tasks/arc/arc_handler.py b/skythought/tools/tasks/arc/arc_handler.py index eb2dc87..8efc511 100644 --- a/skythought/tools/tasks/arc/arc_handler.py +++ b/skythought/tools/tasks/arc/arc_handler.py @@ -1,13 +1,13 @@ import re from typing import Any, Dict -from tasks.common import TaskHandler +from tasks.common import TaskHandler, TaskConfig from util.math_parsing_util import extract_answer class ARCChallengeTaskHandler(TaskHandler): - def __init__(self, yaml_file_path) -> 
None: - super().__init__(yaml_file_path) + def __init__(self, task_config: TaskConfig) -> None: + super().__init__(task_config) self.ans_re = re.compile(r"[Tt]he best answer is ([A-D])[\.\,]*", re.IGNORECASE) self.letter_re = re.compile(r"([A-D])[\.\,]*") self.canonical_options = ["A", "B", "C", "D"] diff --git a/skythought/tools/tasks/common.py b/skythought/tools/tasks/common.py index 47ef30d..602eadc 100644 --- a/skythought/tools/tasks/common.py +++ b/skythought/tools/tasks/common.py @@ -9,7 +9,8 @@ class TaskConfig(BaseModel): - dataset_name: str + handler: str + dataset_path: str dataset_source: Optional[str] = None dataset_split: str dataset_kwargs: Optional[Dict[str, Any]] = None @@ -19,18 +20,30 @@ class TaskConfig(BaseModel): fewshot_config: List[Dict[str, Any]] = Field(default_factory=list) num_fewshot: int = 0 + @property + def handler_cls(self): + from tasks import TASK_HANDLER_MAP + + return TASK_HANDLER_MAP[self.handler] + + @classmethod + def from_yaml(cls, yaml_file_path) -> "TaskConfig": + with open(yaml_file_path, "r", encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + return cls(**config_dict) + + class TaskHandler: task_config_cls = TaskConfig - def __init__(self, yaml_file_path): - self.yaml_file_path = yaml_file_path - self.task_config = self.task_config_cls(**self.load_yaml(yaml_file_path)) - - @staticmethod - def load_yaml(yaml_file_path): - with open(yaml_file_path, "r", encoding="utf-8") as f: - return yaml.safe_load(f) + def __init__(self, task_config: TaskConfig): + self.task_config = task_config + + @classmethod + def from_config_path(cls, config_path: str) -> "TaskHandler": + task_config = cls.task_config_cls.from_yaml(config_path) + return cls(task_config) @property def question_key(self): @@ -54,7 +67,7 @@ def load_existing_results(self, result_file): def load_dataset(self, source=None, split=None, **kwargs) -> HFDataset: dataset = load_dataset( - path=self.task_config.dataset_name, + path=self.task_config.dataset_path, name=source if source else self.task_config.dataset_source, split=split if split else self.task_config.dataset_split, **self.task_config.dataset_kwargs diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml index 4f577ab..3ba7c7c 100644 --- a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml +++ b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml @@ -1,4 +1,5 @@ -dataset_name: Idavidrein/gpqa +handler: gpqa_diamond +dataset_path: Idavidrein/gpqa dataset_source: gpqa_diamond dataset_split: train question_key: Question diff --git a/skythought/tools/tasks/gsm8k/gsm8k.yaml b/skythought/tools/tasks/gsm8k/gsm8k.yaml index bcd5920..995106f 100644 --- a/skythought/tools/tasks/gsm8k/gsm8k.yaml +++ b/skythought/tools/tasks/gsm8k/gsm8k.yaml @@ -1,4 +1,5 @@ -dataset_name: "openai/gsm8k" +handler: gsm8k +dataset_path: "openai/gsm8k" dataset_source: main dataset_split: test question_key: question diff --git a/skythought/tools/tasks/gsm8k/gsm8k_handler.py b/skythought/tools/tasks/gsm8k/gsm8k_handler.py index af811b5..94b1e3d 100644 --- a/skythought/tools/tasks/gsm8k/gsm8k_handler.py +++ b/skythought/tools/tasks/gsm8k/gsm8k_handler.py @@ -1,13 +1,13 @@ import re from typing import Any, Dict -from tasks.common import TaskHandler +from tasks.common import TaskHandler, TaskConfig from util.math_parsing_util import extract_answer class GSM8KTaskHandler(TaskHandler): - def __init__(self, yaml_file_path: str) -> None: - super().__init__(yaml_file_path) + def __init__(self, 
task_config: TaskConfig) -> None: + super().__init__(task_config) self.ans_re = re.compile(r"((-?[$0-9.,]{2,})|(-?[0-9]+))") self.gt_re = re.compile(r"#### (\-?[0-9\.\,]+)") self.invalid_ans = "[invalid]" diff --git a/skythought/tools/tasks/livecodebench/livecodebench.yaml b/skythought/tools/tasks/livecodebench/livecodebench.yaml index 910debb..427aadc 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench.yaml +++ b/skythought/tools/tasks/livecodebench/livecodebench.yaml @@ -1,4 +1,5 @@ -dataset_name: "livecodebench/code_generation_lite" # repo ID in huggingface +handler: livecodebench +dataset_path: "livecodebench/code_generation_lite" # repo ID in huggingface dataset_source: null dataset_split: test dataset_kwargs: diff --git a/skythought/tools/tasks/livecodebench/livecodebench_handler.py b/skythought/tools/tasks/livecodebench/livecodebench_handler.py index 56d2aea..b5c8b6f 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench_handler.py +++ b/skythought/tools/tasks/livecodebench/livecodebench_handler.py @@ -17,7 +17,7 @@ class LiveCodeBenchTaskConfig(TaskConfig): class LiveCodeBenchTaskHandler(TaskHandler): task_config_cls = LiveCodeBenchTaskConfig - + def generate_prompt(self, problem): if problem["is_stdin"]: return self.task_config.templating_parameters["stdin_template"].format( diff --git a/skythought/tools/tasks/math/math500.yaml b/skythought/tools/tasks/math/math500.yaml index 24c8927..6eb73f2 100644 --- a/skythought/tools/tasks/math/math500.yaml +++ b/skythought/tools/tasks/math/math500.yaml @@ -1,4 +1,5 @@ -dataset_name: "qq8933/MATH500" # repo ID in huggingface +handler: math500 +dataset_path: "qq8933/MATH500" # repo ID in huggingface dataset_source: null # which subset on huggingface question_key: problem dataset_split: test diff --git a/skythought/tools/tasks/mmlu/mmlu.yaml b/skythought/tools/tasks/mmlu/mmlu.yaml index bae4caf..0f4f1d8 100644 --- a/skythought/tools/tasks/mmlu/mmlu.yaml +++ b/skythought/tools/tasks/mmlu/mmlu.yaml @@ -1,4 +1,5 @@ -dataset_name: "cais/mmlu" +handler: mmlu +dataset_path: "cais/mmlu" dataset_source: default dataset_split: test templating_parameters: diff --git a/skythought/tools/tasks/mmlu/mmlu_handler.py b/skythought/tools/tasks/mmlu/mmlu_handler.py index 0739122..b0749a1 100644 --- a/skythought/tools/tasks/mmlu/mmlu_handler.py +++ b/skythought/tools/tasks/mmlu/mmlu_handler.py @@ -2,7 +2,7 @@ from util.math_parsing_util import get_multiple_choice_answer, mmlu_pro_extract_answer -from ..common import TaskHandler +from ..common import TaskHandler, TaskConfig class MMLUTaskHandler(TaskHandler): @@ -71,8 +71,8 @@ def load_and_filter_dataset( class MMLUProTaskHandler(MMLUTaskHandler): - def __init__(self, yaml_file_path): - super().__init__(yaml_file_path) + def __init__(self, task_config: TaskConfig): + super().__init__(task_config) self.choices = [ "A", "B", diff --git a/skythought/tools/tasks/mmlu/mmlu_pro.yaml b/skythought/tools/tasks/mmlu/mmlu_pro.yaml index 7305cf3..b744e80 100644 --- a/skythought/tools/tasks/mmlu/mmlu_pro.yaml +++ b/skythought/tools/tasks/mmlu/mmlu_pro.yaml @@ -1,4 +1,5 @@ -dataset_name: TIGER-Lab/MMLU-Pro +handler: mmlu_pro +dataset_path: TIGER-Lab/MMLU-Pro dataset_source: default dataset_split: test templating_parameters: diff --git a/skythought/tools/tasks/numina/numina.yaml b/skythought/tools/tasks/numina/numina.yaml index f9849b0..3eed955 100644 --- a/skythought/tools/tasks/numina/numina.yaml +++ b/skythought/tools/tasks/numina/numina.yaml @@ -1,4 +1,5 @@ -dataset_name: "AI-MO/NuminaMath-CoT" +handler: 
numina +dataset_path: "AI-MO/NuminaMath-CoT" dataset_source: default dataset_split: train question_key: problem diff --git a/skythought/tools/tasks/numina/numina_handler.py b/skythought/tools/tasks/numina/numina_handler.py index b941220..0bf8df5 100644 --- a/skythought/tools/tasks/numina/numina_handler.py +++ b/skythought/tools/tasks/numina/numina_handler.py @@ -13,7 +13,7 @@ class NUMINATaskConfig(TaskConfig): class NUMINATaskHandler(TaskHandler): task_config_cls = NUMINATaskConfig - + def generate_prompt(self, prompt): return self.task_config.templating_parameters["template"].format(prompt=prompt) diff --git a/skythought/tools/tasks/taco/taco.yaml b/skythought/tools/tasks/taco/taco.yaml index 4a7b6a1..5c02b4e 100644 --- a/skythought/tools/tasks/taco/taco.yaml +++ b/skythought/tools/tasks/taco/taco.yaml @@ -1,4 +1,5 @@ -dataset_name: "BAAI/TACO" +handler: taco +dataset_path: "BAAI/TACO" dataset_source: default dataset_split: ALL dataset_kwargs: diff --git a/skythought/tools/tasks/taco/taco_handler.py b/skythought/tools/tasks/taco/taco_handler.py index fc014ee..1cd8b6e 100644 --- a/skythought/tools/tasks/taco/taco_handler.py +++ b/skythought/tools/tasks/taco/taco_handler.py @@ -16,7 +16,6 @@ class TACOTaskConfig(TaskConfig): class TACOTaskHandler(TaskHandler): - task_config_cls = TACOTaskConfig def generate_prompt(self, prompt, starter_code=None, fn_name=None): _input = self.task_config.templating_parameters["initial_template"].format( diff --git a/skythought/tools/tasks/task_util.py b/skythought/tools/tasks/task_util.py new file mode 100644 index 0000000..5a736c8 --- /dev/null +++ b/skythought/tools/tasks/task_util.py @@ -0,0 +1,16 @@ +import glob +import os +from typing import Dict + +def get_tasks(task_root_dir: str) -> Dict[str, str]: + """Returns a dictionary of task names and their corresponding yaml file paths""" + # list all yamls in subdirectories + name_to_yaml = {} + for yaml_file in glob.glob(os.path.join(task_root_dir, "**", "*.yaml"), recursive=True): + # arc.yaml -> arc + name = os.path.basename(yaml_file).split(".")[0] + + name_to_yaml[name] = yaml_file + + return name_to_yaml + From 7bafae14e5fb4300bc8dda0ef82d5ca13ff35cbd Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 24 Jan 2025 22:47:12 +0000 Subject: [PATCH 14/47] mid commit Signed-off-by: SumanthRH --- skythought/tools/eval.py | 10 ++++++---- skythought/tools/inference_and_check.py | 5 ++++- skythought/tools/tasks/aime/aime_handler.py | 2 +- skythought/tools/tasks/amc23/amc23_handler.py | 2 +- skythought/tools/tasks/apps/apps_handler.py | 4 ++-- skythought/tools/tasks/arc/arc_handler.py | 2 +- skythought/tools/tasks/common.py | 2 +- .../tools/tasks/gpqa_diamond/gpqa_diamond_handler.py | 2 +- skythought/tools/tasks/gsm8k/gsm8k_handler.py | 2 +- .../tools/tasks/livecodebench/livecodebench_handler.py | 4 ++-- skythought/tools/tasks/math/math_handler.py | 2 +- skythought/tools/tasks/mmlu/mmlu_handler.py | 4 ++-- skythought/tools/tasks/numina/numina_handler.py | 4 ++-- skythought/tools/tasks/taco/taco_handler.py | 2 +- 14 files changed, 26 insertions(+), 21 deletions(-) diff --git a/skythought/tools/eval.py b/skythought/tools/eval.py index f5fee49..73027c7 100644 --- a/skythought/tools/eval.py +++ b/skythought/tools/eval.py @@ -21,7 +21,10 @@ def parse_arguments(): ) parser.add_argument("--tp", type=int, default=8, help="Tensor Parallelism Degree") parser.add_argument( - "--filter-difficulty", action="store_true", help="Filter difficulty." 
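For illustration, here is a minimal, self-contained sketch of the task-discovery rule that the new tasks/task_util.py introduces (recursive glob over the task directories, mapping each YAML's basename to its path). The directory and file names below are hypothetical; only the glob/basename logic mirrors the patch:

import glob
import os
import tempfile

# Build a throwaway directory tree shaped roughly like skythought/tools/tasks/.
with tempfile.TemporaryDirectory() as root:
    for rel in ("aime/aime.yaml", "mmlu/mmlu_pro.yaml"):
        path = os.path.join(root, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "w").close()  # empty placeholder config file

    # Same rule as get_tasks(): recursive glob, strip the extension, map name -> yaml path.
    name_to_yaml = {
        os.path.basename(p).split(".")[0]: p
        for p in glob.glob(os.path.join(root, "**", "*.yaml"), recursive=True)
    }
    print(sorted(name_to_yaml))  # ['aime', 'mmlu_pro']

This discovered mapping is what lets eval.py and inference_and_check.py validate --task against TASK_NAMES_TO_YAML.keys() instead of a hand-maintained task list.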
+ "--filter-difficulty", + type=str, + default=None, + help="Optional filter difficulty. Options: 'easy', 'medium', 'hard'.", ) parser.add_argument("--source", type=str, help="Source for the dataset.") parser.add_argument( @@ -96,10 +99,9 @@ def main(): command.extend(temperatures) # Add temperatures as separate arguments if args.filter_difficulty: - assert args.source != "", "No source passed for filtering difficulty." command.append("--filter-difficulty") - command.append("--source") - command.append(args.source) + command.append(args.filter_difficulty) + print(f"Running eval {eval_name} with command {command}") all_logs += f"\nRunning eval: {eval_name} with command {command}\n" try: diff --git a/skythought/tools/inference_and_check.py b/skythought/tools/inference_and_check.py index 5682962..fcba2c7 100644 --- a/skythought/tools/inference_and_check.py +++ b/skythought/tools/inference_and_check.py @@ -447,7 +447,10 @@ def main(): parser.add_argument("--start", type=int, default=0, help="Start index.") parser.add_argument("--end", type=int, default=-1, help="End index.") parser.add_argument( - "--filter-difficulty", action="store_true", help="Filter difficulty." + "--filter-difficulty", + type=str, + default=None, + help="Optional filter difficulty. Options: 'easy', 'medium', 'hard'.", ) parser.add_argument( "--result-dir", type=str, default="./", help="Result dir to save files." diff --git a/skythought/tools/tasks/aime/aime_handler.py b/skythought/tools/tasks/aime/aime_handler.py index 4b11a19..5b7b6d7 100644 --- a/skythought/tools/tasks/aime/aime_handler.py +++ b/skythought/tools/tasks/aime/aime_handler.py @@ -26,7 +26,7 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=None, args=None ): train_data = self.load_dataset(source=source, split=split).to_pandas() filtered_data = train_data[train_data["url"].str.contains("2024", na=False)] diff --git a/skythought/tools/tasks/amc23/amc23_handler.py b/skythought/tools/tasks/amc23/amc23_handler.py index 5b765cc..ff598fb 100644 --- a/skythought/tools/tasks/amc23/amc23_handler.py +++ b/skythought/tools/tasks/amc23/amc23_handler.py @@ -3,7 +3,7 @@ class AMC23TaskHandler(MathTaskHandler): def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=None, args=None ): train_data = self.load_dataset(source=source, split=split).to_pandas() filtered_data = train_data[train_data["url"].str.contains("2023", na=False)] diff --git a/skythought/tools/tasks/apps/apps_handler.py b/skythought/tools/tasks/apps/apps_handler.py index b849e25..fc87649 100644 --- a/skythought/tools/tasks/apps/apps_handler.py +++ b/skythought/tools/tasks/apps/apps_handler.py @@ -102,12 +102,12 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=None, args=None ): train_data = self.load_dataset(source=source, split=split).to_pandas() if filter_difficulty or self.task_config.difficulty: difficulty = ( - self.task_config.difficulty if not filter_difficulty else source + self.task_config.difficulty if not filter_difficulty else filter_difficulty ) train_data = 
train_data.filter(lambda x: x["difficulty"] == difficulty) diff --git a/skythought/tools/tasks/arc/arc_handler.py b/skythought/tools/tasks/arc/arc_handler.py index 8efc511..d403f46 100644 --- a/skythought/tools/tasks/arc/arc_handler.py +++ b/skythought/tools/tasks/arc/arc_handler.py @@ -66,7 +66,7 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=None, args=None ): train_data = self.load_dataset(source=source, split=split).to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] diff --git a/skythought/tools/tasks/common.py b/skythought/tools/tasks/common.py index 602eadc..49d88e7 100644 --- a/skythought/tools/tasks/common.py +++ b/skythought/tools/tasks/common.py @@ -75,7 +75,7 @@ def load_dataset(self, source=None, split=None, **kwargs) -> HFDataset: return dataset def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None + self, start, end, split="train", source=None, filter_difficulty=None, args=None ): raise NotImplementedError("Subclasses should implement this method.") diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py index 010c1c4..017e8c3 100644 --- a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py +++ b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py @@ -81,7 +81,7 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=None, args=None ): train_data = self.load_dataset(source=source, split=split).to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] diff --git a/skythought/tools/tasks/gsm8k/gsm8k_handler.py b/skythought/tools/tasks/gsm8k/gsm8k_handler.py index 94b1e3d..b2042e6 100644 --- a/skythought/tools/tasks/gsm8k/gsm8k_handler.py +++ b/skythought/tools/tasks/gsm8k/gsm8k_handler.py @@ -53,7 +53,7 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=None, args=None ): train_data = self.load_dataset(source=source, split=split).to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] diff --git a/skythought/tools/tasks/livecodebench/livecodebench_handler.py b/skythought/tools/tasks/livecodebench/livecodebench_handler.py index b5c8b6f..f16e8e8 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench_handler.py +++ b/skythought/tools/tasks/livecodebench/livecodebench_handler.py @@ -17,7 +17,7 @@ class LiveCodeBenchTaskConfig(TaskConfig): class LiveCodeBenchTaskHandler(TaskHandler): task_config_cls = LiveCodeBenchTaskConfig - + def generate_prompt(self, problem): if problem["is_stdin"]: return self.task_config.templating_parameters["stdin_template"].format( @@ -101,7 +101,7 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, 
filter_difficulty=None, args=None ): dataset = self.load_dataset(source=source, split=split) # Filter by CLI or config diff --git a/skythought/tools/tasks/math/math_handler.py b/skythought/tools/tasks/math/math_handler.py index 3d5e226..3d9a550 100644 --- a/skythought/tools/tasks/math/math_handler.py +++ b/skythought/tools/tasks/math/math_handler.py @@ -51,7 +51,7 @@ def process_remaining_data(self, train_data, results): ] def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=False, args=None + self, start, end, split=None, source=None, filter_difficulty=None, args=None ): dataset = self.load_dataset(source=source, split=split) return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] diff --git a/skythought/tools/tasks/mmlu/mmlu_handler.py b/skythought/tools/tasks/mmlu/mmlu_handler.py index b0749a1..16991d9 100644 --- a/skythought/tools/tasks/mmlu/mmlu_handler.py +++ b/skythought/tools/tasks/mmlu/mmlu_handler.py @@ -63,7 +63,7 @@ def process_remaining_data(self, train_data, results): ] def load_and_filter_dataset( - self, start, end, split="test", source=None, filter_difficulty=False, args=None + self, start, end, split="test", source=None, filter_difficulty=None, args=None ): dataset = load_dataset(self.dataset, "all") train_data = dataset[split].to_pandas() @@ -108,7 +108,7 @@ def get_multiple_choice_answers(self, problem): return f"Answer Choices: {options}" def load_and_filter_dataset( - self, start, end, split="test", source=None, filter_difficulty=False, args=None + self, start, end, split="test", source=None, filter_difficulty=None, args=None ): dataset = self.load_dataset(source=source, split=split) return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] diff --git a/skythought/tools/tasks/numina/numina_handler.py b/skythought/tools/tasks/numina/numina_handler.py index 0bf8df5..f19742c 100644 --- a/skythought/tools/tasks/numina/numina_handler.py +++ b/skythought/tools/tasks/numina/numina_handler.py @@ -13,7 +13,7 @@ class NUMINATaskConfig(TaskConfig): class NUMINATaskHandler(TaskHandler): task_config_cls = NUMINATaskConfig - + def generate_prompt(self, prompt): return self.task_config.templating_parameters["template"].format(prompt=prompt) @@ -75,7 +75,7 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None + self, start, end, split="train", source=None, filter_difficulty=None, args=None ): dataset = self.load_dataset(source=source, split=split).to_pandas() dataset = dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] diff --git a/skythought/tools/tasks/taco/taco_handler.py b/skythought/tools/tasks/taco/taco_handler.py index 1cd8b6e..2cc222a 100644 --- a/skythought/tools/tasks/taco/taco_handler.py +++ b/skythought/tools/tasks/taco/taco_handler.py @@ -113,7 +113,7 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=False, args=None + self, start, end, split="train", source=None, filter_difficulty=None, args=None ): dataset = self.load_dataset(source=source, split=split).to_pandas() if filter_difficulty or self.task_config.difficulty: From 829cb4ce916919494ca119a3c3ad1d559821d1e9 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Sat, 25 Jan 2025 01:46:54 +0000 Subject: [PATCH 15/47] x Signed-off-by: SumanthRH --- skythought/tools/eval.py | 5 +++-- 
skythought/tools/tasks/aime/aime_handler.py | 7 ++++--- skythought/tools/tasks/common.py | 2 +- skythought/tools/tasks/math/math500.yaml | 12 ++++++------ skythought/tools/tasks/math/math_handler.py | 2 +- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/skythought/tools/eval.py b/skythought/tools/eval.py index 73027c7..71f69b0 100644 --- a/skythought/tools/eval.py +++ b/skythought/tools/eval.py @@ -3,7 +3,7 @@ import subprocess import os -from skythought.tools.tasks.task_util import get_tasks +from tasks.task_util import get_tasks module_dir = os.path.dirname(os.path.abspath(__file__)) TASK_NAMES_TO_YAML = get_tasks(os.path.join(module_dir, "tasks")) @@ -85,12 +85,13 @@ def main(): # Run the Python command for each eval and collect logs for eval_name in evals: eval_name = eval_name.lower() + assert eval_name in TASK_NAMES_TO_YAML.keys(), f"Task {eval_name} not found, should be one of {TASK_NAMES_TO_YAML.keys()}" command = [ "python", script_path, "--model", model_path, - "--dataset", + "--task", eval_name, "--tp", str(tp), diff --git a/skythought/tools/tasks/aime/aime_handler.py b/skythought/tools/tasks/aime/aime_handler.py index 5b7b6d7..17bf62a 100644 --- a/skythought/tools/tasks/aime/aime_handler.py +++ b/skythought/tools/tasks/aime/aime_handler.py @@ -1,16 +1,17 @@ +from typing import Dict from tasks.math.math_handler import MathTaskHandler from util.model_utils import MODEL_TO_NAME class AIMETaskHandler(MathTaskHandler): - def generate_prompt(self, problem, model): + def generate_prompt(self, problem: Dict, model): if MODEL_TO_NAME[model] == "Sky-T1-32B-Preview": return self.task_config.templating_parameters["sky_template"].format( - **problem + prompt=problem["problem"] ) else: return self.task_config.templating_parameters["regular_template"].format( - **problem + prompt=problem["problem"] ) def make_conversations(self, data, system_prompt, model=None): diff --git a/skythought/tools/tasks/common.py b/skythought/tools/tasks/common.py index 49d88e7..160a8fa 100644 --- a/skythought/tools/tasks/common.py +++ b/skythought/tools/tasks/common.py @@ -13,7 +13,7 @@ class TaskConfig(BaseModel): dataset_path: str dataset_source: Optional[str] = None dataset_split: str - dataset_kwargs: Optional[Dict[str, Any]] = None + dataset_kwargs: Dict[str, Any] = Field(default_factory=dict) question_key: str templating_parameters: Dict[str, str] = Field(default_factory=dict) # Optional, unused for now diff --git a/skythought/tools/tasks/math/math500.yaml b/skythought/tools/tasks/math/math500.yaml index 6eb73f2..1fe89e2 100644 --- a/skythought/tools/tasks/math/math500.yaml +++ b/skythought/tools/tasks/math/math500.yaml @@ -4,9 +4,9 @@ dataset_source: null # which subset on huggingface question_key: problem dataset_split: test templating_parameters: - template: "Return your final response within \\boxed{{}}. {problem}" - # optional. Not supported yet. -fewshot_config: - - question: ... - - target: ... -num_fewshot: 0 + template: "Return your final response within \\boxed{{}}. {problem}" +# optional. Not supported yet. +# fewshot_config: +# - question: ... +# - target: ... 
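As a quick illustration of the templating_parameters convention these task YAMLs rely on, a hedged, self-contained sketch follows; the inlined config dict and the example row are hypothetical stand-ins for a loaded task YAML and a dataset record:

# Inline stand-in for a task YAML's templating_parameters section (per the patch,
# the real values are loaded via TaskConfig.from_yaml and read off task_config).
templating_parameters = {
    "template": "Return your final response within \\boxed{{}}. {problem}"
}

# Hypothetical dataset row; handlers feed the record's fields into str.format.
problem = {"problem": "What is 2 + 2?", "answer": "4"}

prompt = templating_parameters["template"].format(**problem)
print(prompt)  # Return your final response within \boxed{}. What is 2 + 2?

The doubled braces in the YAML survive formatting as literal \boxed{}, while named placeholders such as {problem} or {prompt} are filled from the record, which is why the handlers above can drop their hard-coded prompt strings.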
+# num_fewshot: 0 diff --git a/skythought/tools/tasks/math/math_handler.py b/skythought/tools/tasks/math/math_handler.py index 3d9a550..e7faa9b 100644 --- a/skythought/tools/tasks/math/math_handler.py +++ b/skythought/tools/tasks/math/math_handler.py @@ -53,5 +53,5 @@ def process_remaining_data(self, train_data, results): def load_and_filter_dataset( self, start, end, split=None, source=None, filter_difficulty=None, args=None ): - dataset = self.load_dataset(source=source, split=split) + dataset = self.load_dataset(source=source, split=split).to_pandas() return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] From 3188aebd373f5b56df48c8ec6dc3adbf6c7ccc82 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Sat, 25 Jan 2025 20:35:01 +0000 Subject: [PATCH 16/47] add answer_key Signed-off-by: SumanthRH --- skythought/tools/tasks/__init__.py | 7 +++++++ skythought/tools/tasks/aime/aime.yaml | 1 + skythought/tools/tasks/amc23/amc23.yaml | 1 + skythought/tools/tasks/apps/apps.yaml | 1 + skythought/tools/tasks/arc/arc_c.yaml | 1 + skythought/tools/tasks/arc/arc_handler.py | 4 ++-- skythought/tools/tasks/common.py | 2 ++ .../tools/tasks/gpqa_diamond/gpqa_diamond.yaml | 1 + .../tasks/gpqa_diamond/gpqa_diamond_handler.py | 2 +- skythought/tools/tasks/gsm8k/gsm8k.yaml | 1 + skythought/tools/tasks/gsm8k/gsm8k_handler.py | 2 +- .../tools/tasks/livecodebench/livecodebench.yaml | 1 + skythought/tools/tasks/math/math500.yaml | 3 ++- skythought/tools/tasks/math/math_handler.py | 2 +- .../tools/tasks/minervamath/minervamath.yaml | 8 ++++++++ .../tools/tasks/minervamath/minervamath_handler.py | 8 +------- skythought/tools/tasks/mmlu/mmlu.yaml | 6 ++++-- skythought/tools/tasks/mmlu/mmlu_handler.py | 2 +- skythought/tools/tasks/mmlu/mmlu_pro.yaml | 2 ++ skythought/tools/tasks/numina/numina.yaml | 1 + skythought/tools/tasks/numina/numina_handler.py | 2 +- .../tasks/olympiadbench/olympiadbench_handler.py | 14 +------------- .../tasks/olympiadbench/olympiadbench_math_en.yaml | 8 ++++++++ skythought/tools/tasks/taco/taco.yaml | 6 ++++-- 24 files changed, 54 insertions(+), 32 deletions(-) create mode 100644 skythought/tools/tasks/minervamath/minervamath.yaml create mode 100644 skythought/tools/tasks/olympiadbench/olympiadbench_math_en.yaml diff --git a/skythought/tools/tasks/__init__.py b/skythought/tools/tasks/__init__.py index 80f357a..222bcb2 100644 --- a/skythought/tools/tasks/__init__.py +++ b/skythought/tools/tasks/__init__.py @@ -10,6 +10,8 @@ from .mmlu.mmlu_handler import MMLUProTaskHandler, MMLUTaskHandler from .numina.numina_handler import NUMINATaskHandler from .taco.taco_handler import TACOTaskHandler +from .minervamath.minervamath_handler import MinervaMathTaskHandler +from .olympiadbench.olympiadbench_handler import OlympiadBenchMathTaskHandler TASK_HANDLER_MAP = { "numina": NUMINATaskHandler, @@ -24,6 +26,8 @@ "gsm8k": GSM8KTaskHandler, "arc_c": ARCChallengeTaskHandler, "amc23": AMC23TaskHandler, + "minervamath": MinervaMathTaskHandler, + "olympiadbench_math": OlympiadBenchMathTaskHandler, } __all__ = [ @@ -41,5 +45,8 @@ ARCChallengeTaskHandler, TaskHandler, MathTaskHandler, + OlympiadBenchMathTaskHandler, + MinervaMathTaskHandler, + TaskConfig, TASK_HANDLER_MAP, ] diff --git a/skythought/tools/tasks/aime/aime.yaml b/skythought/tools/tasks/aime/aime.yaml index 5c2aa83..8df89f0 100644 --- a/skythought/tools/tasks/aime/aime.yaml +++ b/skythought/tools/tasks/aime/aime.yaml @@ -2,6 +2,7 @@ handler: aime dataset_path: AI-MO/aimo-validation-aime dataset_split: train question_key: problem 
+answer_key: answer templating_parameters: regular_template: "Return your final response within \\boxed{{}}. {prompt}" sky_template: "{prompt}\nReturn your final response within \\boxed{{}}" \ No newline at end of file diff --git a/skythought/tools/tasks/amc23/amc23.yaml b/skythought/tools/tasks/amc23/amc23.yaml index 627e31d..7f5186d 100644 --- a/skythought/tools/tasks/amc23/amc23.yaml +++ b/skythought/tools/tasks/amc23/amc23.yaml @@ -4,4 +4,5 @@ dataset_kwargs: trust_remote_code: true dataset_split: train question_key: problem +answer_key: answer difficulty: null diff --git a/skythought/tools/tasks/apps/apps.yaml b/skythought/tools/tasks/apps/apps.yaml index 1ee98df..4fca2aa 100644 --- a/skythought/tools/tasks/apps/apps.yaml +++ b/skythought/tools/tasks/apps/apps.yaml @@ -4,6 +4,7 @@ dataset_kwargs: trust_remote_code: true dataset_split: train question_key: question +answer_key: null difficulty: null templating_parameters: with_fn_name_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}" diff --git a/skythought/tools/tasks/arc/arc_c.yaml b/skythought/tools/tasks/arc/arc_c.yaml index 126e165..09f83e9 100644 --- a/skythought/tools/tasks/arc/arc_c.yaml +++ b/skythought/tools/tasks/arc/arc_c.yaml @@ -3,6 +3,7 @@ dataset_path: allenai/ai2_arc dataset_source: ARC-Challenge dataset_split: train question_key: question +answer_key: answerKey templating_parameters: # We combine choices for a question into choices_text entry in the dataset template: "Given the following question and four candidate answers (A, B, C and D), choose the best answer. Your response should end with \"The best answer is [the_answer_letter]\" where [the_answer_letter] is one of the four letter choice (A, B, C, or D).\n{question}\n{choices_text}" \ No newline at end of file diff --git a/skythought/tools/tasks/arc/arc_handler.py b/skythought/tools/tasks/arc/arc_handler.py index d403f46..2b00bfe 100644 --- a/skythought/tools/tasks/arc/arc_handler.py +++ b/skythought/tools/tasks/arc/arc_handler.py @@ -28,9 +28,9 @@ def generate_prompt(self, problem): return full_prompt def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: - gt_answer = problem["answerKey"] + gt_answer = problem[self.task_config.answer_key] if gt_answer not in self.canonical_options: - gt_answer = self.canonical_options[int(problem["answerKey"]) - 1] + gt_answer = self.canonical_options[int(problem[self.task_config.answer_key]) - 1] model_answer = self.get_answer(generation) return model_answer == gt_answer diff --git a/skythought/tools/tasks/common.py b/skythought/tools/tasks/common.py index 160a8fa..ebac3b3 100644 --- a/skythought/tools/tasks/common.py +++ b/skythought/tools/tasks/common.py @@ -15,6 +15,8 @@ class TaskConfig(BaseModel): dataset_split: str dataset_kwargs: Dict[str, Any] = Field(default_factory=dict) question_key: str + # Optional answer key for datasets with a single correct answer + answer_key: Optional[str] = None templating_parameters: Dict[str, str] = Field(default_factory=dict) # Optional, unused for now fewshot_config: List[Dict[str, Any]] = Field(default_factory=list) diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml index 3ba7c7c..940d960 100644 --- a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml +++ b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml @@ -3,6 +3,7 @@ dataset_path: 
Idavidrein/gpqa dataset_source: gpqa_diamond dataset_split: train question_key: Question +answer_key: Answer templating_parameters: # For GPQA, we combine the Question key and the multiple choice answers into a single `prompt` entry template: "Return your final response within \\boxed{{}} and only include the letter choice (A, B, C, or D) as your final response. {prompt}" \ No newline at end of file diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py index 017e8c3..6055956 100644 --- a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py +++ b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py @@ -30,7 +30,7 @@ def update_results(self, problem, response): def check_correctness(self, problem, generation): pred = get_multiple_choice_answer(generation) - answer = problem["Answer"] + answer = problem[self.task_config.answer_key] return answer == pred def get_multiple_choice_answers(self, data): diff --git a/skythought/tools/tasks/gsm8k/gsm8k.yaml b/skythought/tools/tasks/gsm8k/gsm8k.yaml index 995106f..2ef5012 100644 --- a/skythought/tools/tasks/gsm8k/gsm8k.yaml +++ b/skythought/tools/tasks/gsm8k/gsm8k.yaml @@ -3,6 +3,7 @@ dataset_path: "openai/gsm8k" dataset_source: main dataset_split: test question_key: question +answer_key: answer templating_parameters: template: "Given the following problem, reason and give a final answer to the problem.\nProblem: {question}\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem." diff --git a/skythought/tools/tasks/gsm8k/gsm8k_handler.py b/skythought/tools/tasks/gsm8k/gsm8k_handler.py index b2042e6..7820bb1 100644 --- a/skythought/tools/tasks/gsm8k/gsm8k_handler.py +++ b/skythought/tools/tasks/gsm8k/gsm8k_handler.py @@ -16,7 +16,7 @@ def generate_prompt(self, problem): return self.task_config.templating_parameters["template"].format(**problem) def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: - gt_answer = self.extract_gt_answer(problem["answer"]) + gt_answer = self.extract_gt_answer(problem[self.task_config.answer_key]) model_answer = extract_answer(generation) model_answer = self.sanitize_answer(model_answer) return model_answer == gt_answer diff --git a/skythought/tools/tasks/livecodebench/livecodebench.yaml b/skythought/tools/tasks/livecodebench/livecodebench.yaml index 427aadc..a1fd7fd 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench.yaml +++ b/skythought/tools/tasks/livecodebench/livecodebench.yaml @@ -6,6 +6,7 @@ dataset_kwargs: version_tag: release_v2 trust_remote_code: true question_key: task_id +answer_key: null templating_parameters: stdin_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}" non_stdin_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. 
{prompt}" diff --git a/skythought/tools/tasks/math/math500.yaml b/skythought/tools/tasks/math/math500.yaml index 1fe89e2..43c0e82 100644 --- a/skythought/tools/tasks/math/math500.yaml +++ b/skythought/tools/tasks/math/math500.yaml @@ -1,7 +1,8 @@ -handler: math500 +handler: math dataset_path: "qq8933/MATH500" # repo ID in huggingface dataset_source: null # which subset on huggingface question_key: problem +answer_key: answer dataset_split: test templating_parameters: template: "Return your final response within \\boxed{{}}. {problem}" diff --git a/skythought/tools/tasks/math/math_handler.py b/skythought/tools/tasks/math/math_handler.py index 6d98811..c281deb 100644 --- a/skythought/tools/tasks/math/math_handler.py +++ b/skythought/tools/tasks/math/math_handler.py @@ -7,7 +7,7 @@ def generate_prompt(self, problem): return self.task_config.templating_parameters["template"].format(**problem) def check_correctness(self, problem, generation): - answer = strip_answer_string(problem["answer"]) + answer = strip_answer_string(problem[self.task_config.answer_key]) pred = extract_answer(generation) pred = strip_answer_string(pred) return math_equal(pred, answer) diff --git a/skythought/tools/tasks/minervamath/minervamath.yaml b/skythought/tools/tasks/minervamath/minervamath.yaml new file mode 100644 index 0000000..85ba7aa --- /dev/null +++ b/skythought/tools/tasks/minervamath/minervamath.yaml @@ -0,0 +1,8 @@ +handler: math +dataset_path: "svc-huggingface/minerva-math" # repo ID in huggingface +dataset_source: null # which subset on huggingface +question_key: problem +answer_key: solution +dataset_split: test +templating_parameters: + template: "Return your final response within \\boxed{{}}. {problem}" \ No newline at end of file diff --git a/skythought/tools/tasks/minervamath/minervamath_handler.py b/skythought/tools/tasks/minervamath/minervamath_handler.py index de59d65..5742268 100644 --- a/skythought/tools/tasks/minervamath/minervamath_handler.py +++ b/skythought/tools/tasks/minervamath/minervamath_handler.py @@ -3,15 +3,9 @@ from tasks.math.math_handler import MathTaskHandler class MinervaMathTaskHandler(MathTaskHandler): - def __init__(self): - self.dataset = "svc-huggingface/minerva-math" - - @staticmethod - def get_question_key(): - return "problem" def check_correctness(self, problem, generation): - answer = extract_answer(problem["solution"]) + answer = extract_answer(problem[self.task_config.answer_key]) answer = strip_answer_string(answer) pred = extract_answer(generation) diff --git a/skythought/tools/tasks/mmlu/mmlu.yaml b/skythought/tools/tasks/mmlu/mmlu.yaml index 0f4f1d8..ad98fd5 100644 --- a/skythought/tools/tasks/mmlu/mmlu.yaml +++ b/skythought/tools/tasks/mmlu/mmlu.yaml @@ -1,6 +1,8 @@ handler: mmlu -dataset_path: "cais/mmlu" -dataset_source: default +dataset_path: cais/mmlu +dataset_source: all dataset_split: test +question_key: question +answer_key: answer templating_parameters: template: "Return your final response within \\boxed{{}}. 
{prompt}" diff --git a/skythought/tools/tasks/mmlu/mmlu_handler.py b/skythought/tools/tasks/mmlu/mmlu_handler.py index 16991d9..bfa712f 100644 --- a/skythought/tools/tasks/mmlu/mmlu_handler.py +++ b/skythought/tools/tasks/mmlu/mmlu_handler.py @@ -12,7 +12,7 @@ def generate_prompt(self, prompt): def check_correctness(self, problem, generation): pred = get_multiple_choice_answer(generation) abcd = "ABCD" - answer = abcd[problem["answer"]] + answer = abcd[problem[self.task_config.answer_key]] return answer == pred def update_results(self, problem, response): diff --git a/skythought/tools/tasks/mmlu/mmlu_pro.yaml b/skythought/tools/tasks/mmlu/mmlu_pro.yaml index b744e80..4b88e92 100644 --- a/skythought/tools/tasks/mmlu/mmlu_pro.yaml +++ b/skythought/tools/tasks/mmlu/mmlu_pro.yaml @@ -2,5 +2,7 @@ handler: mmlu_pro dataset_path: TIGER-Lab/MMLU-Pro dataset_source: default dataset_split: test +question_key: question +answer_key: answer templating_parameters: template: "Return your final response within \\boxed{{}}. {prompt}" diff --git a/skythought/tools/tasks/numina/numina.yaml b/skythought/tools/tasks/numina/numina.yaml index 3eed955..e1d8e3d 100644 --- a/skythought/tools/tasks/numina/numina.yaml +++ b/skythought/tools/tasks/numina/numina.yaml @@ -3,5 +3,6 @@ dataset_path: "AI-MO/NuminaMath-CoT" dataset_source: default dataset_split: train question_key: problem +answer_key: solution templating_parameters: template: "Return your final response within \\boxed{{}}. {prompt}" diff --git a/skythought/tools/tasks/numina/numina_handler.py b/skythought/tools/tasks/numina/numina_handler.py index f19742c..f2cdafa 100644 --- a/skythought/tools/tasks/numina/numina_handler.py +++ b/skythought/tools/tasks/numina/numina_handler.py @@ -19,7 +19,7 @@ def generate_prompt(self, prompt): @timeout(5) # Add timeout of 5 seconds def check_correctness(self, problem, generation): - solution = extract_answer(problem["solution"]) + solution = extract_answer(problem[self.task_config.answer_key]) solution = strip_answer_string(solution) pred = extract_answer(generation) pred = strip_answer_string(pred) diff --git a/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py b/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py index 6049539..e79d871 100644 --- a/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py +++ b/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py @@ -4,21 +4,9 @@ class OlympiadBenchMathTaskHandler(MathTaskHandler): - def __init__(self): - self.dataset = "Hothan/OlympiadBench" - self.source = "OE_TO_maths_en_COMP" - - @staticmethod - def get_question_key(): - return "question" - - def load_and_filter_dataset(self, start, end, split="train", source=None, filter_difficulty=False, args=None): - dataset = self.load_dataset(source=source, split=split).to_pandas() - return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] - def check_correctness(self, problem, generation): # all problems have final answer in a list - answer = strip_answer_string(problem["final_answer"][0]) + answer = strip_answer_string(problem[self.task_config.answer_key][0]) pred = extract_answer(generation) pred = strip_answer_string(pred) return math_equal(pred, answer) \ No newline at end of file diff --git a/skythought/tools/tasks/olympiadbench/olympiadbench_math_en.yaml b/skythought/tools/tasks/olympiadbench/olympiadbench_math_en.yaml new file mode 100644 index 0000000..8f3d81b --- /dev/null +++ b/skythought/tools/tasks/olympiadbench/olympiadbench_math_en.yaml @@ -0,0 +1,8 @@ +handler: 
olympiadbench_math +dataset_path: Hothan/OlympiadBench +dataset_source: OE_TO_maths_en_COMP +dataset_split: test +question_key: question +answer_key: final_answer +templating_parameters: + template: "Return your final response within \\boxed{{}}. {prompt}" diff --git a/skythought/tools/tasks/taco/taco.yaml b/skythought/tools/tasks/taco/taco.yaml index 5c02b4e..ee33d03 100644 --- a/skythought/tools/tasks/taco/taco.yaml +++ b/skythought/tools/tasks/taco/taco.yaml @@ -1,9 +1,11 @@ handler: taco dataset_path: "BAAI/TACO" -dataset_source: default -dataset_split: ALL +dataset_source: ALL +dataset_split: train dataset_kwargs: trust_remote_code: true +question_key: question +answer_key: null templating_parameters: initial_template: "\nQUESTION:\n{prompt}" # Add starter code to initial template From 612ecaf6c8a15bf1881302cfd5b0534f0b9be201 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Mon, 27 Jan 2025 06:52:28 +0000 Subject: [PATCH 17/47] x Signed-off-by: SumanthRH --- skythought/tools/tasks/__init__.py | 4 ++-- skythought/tools/tasks/amc23/amc23.yaml | 2 ++ skythought/tools/tasks/apps/apps_handler.py | 9 ++++++++- skythought/tools/tasks/arc/arc_handler.py | 2 +- skythought/tools/tasks/{common.py => base.py} | 0 .../tasks/gpqa_diamond/gpqa_diamond_handler.py | 2 +- skythought/tools/tasks/gsm8k/gsm8k_handler.py | 2 +- .../tasks/livecodebench/livecodebench_handler.py | 2 +- skythought/tools/tasks/math/math_handler.py | 2 +- .../tools/tasks/minervamath/minervamath_handler.py | 2 +- skythought/tools/tasks/mmlu/mmlu_handler.py | 14 ++++++-------- skythought/tools/tasks/numina/numina_handler.py | 2 +- .../tasks/olympiadbench/olympiadbench_handler.py | 2 +- .../tasks/olympiadbench/olympiadbench_math_en.yaml | 4 ++-- skythought/tools/tasks/taco/taco.yaml | 1 + skythought/tools/tasks/taco/taco_handler.py | 5 +++-- 16 files changed, 32 insertions(+), 23 deletions(-) rename skythought/tools/tasks/{common.py => base.py} (100%) diff --git a/skythought/tools/tasks/__init__.py b/skythought/tools/tasks/__init__.py index 222bcb2..97cc533 100644 --- a/skythought/tools/tasks/__init__.py +++ b/skythought/tools/tasks/__init__.py @@ -2,7 +2,7 @@ from .amc23.amc23_handler import AMC23TaskHandler from .apps.apps_handler import APPSTaskHandler from .arc.arc_handler import ARCChallengeTaskHandler -from .common import TaskHandler, TaskConfig +from .base import TaskHandler, TaskConfig from .gpqa_diamond.gpqa_diamond_handler import GPQADiamondTaskHandler from .gsm8k.gsm8k_handler import GSM8KTaskHandler from .livecodebench.livecodebench_handler import LiveCodeBenchTaskHandler @@ -27,7 +27,7 @@ "arc_c": ARCChallengeTaskHandler, "amc23": AMC23TaskHandler, "minervamath": MinervaMathTaskHandler, - "olympiadbench_math": OlympiadBenchMathTaskHandler, + "olympiadbench_math_en": OlympiadBenchMathTaskHandler, } __all__ = [ diff --git a/skythought/tools/tasks/amc23/amc23.yaml b/skythought/tools/tasks/amc23/amc23.yaml index 7f5186d..c86ed2a 100644 --- a/skythought/tools/tasks/amc23/amc23.yaml +++ b/skythought/tools/tasks/amc23/amc23.yaml @@ -6,3 +6,5 @@ dataset_split: train question_key: problem answer_key: answer difficulty: null +templating_parameters: + template: "Return your final response within \\boxed{{}}. 
{problem}" \ No newline at end of file diff --git a/skythought/tools/tasks/apps/apps_handler.py b/skythought/tools/tasks/apps/apps_handler.py index fc87649..1e9e3f6 100644 --- a/skythought/tools/tasks/apps/apps_handler.py +++ b/skythought/tools/tasks/apps/apps_handler.py @@ -2,16 +2,23 @@ import json import multiprocessing from multiprocessing import Manager +from typing import Optional import numpy as np from tasks.apps.apps_util import run_test as apps_run_test from util.common import has_code -from ..common import TaskHandler +from ..base import TaskHandler, TaskConfig +class APPSTaskConfig(TaskConfig): + # by default, no filter on difficulty + difficulty: Optional[str] = None + class APPSTaskHandler(TaskHandler): + task_config_cls = APPSTaskConfig + def generate_prompt(self, test_case, prompt, starter_code=None): if not test_case.get("fn_name"): _input = self.task_config.templating_parameters[ diff --git a/skythought/tools/tasks/arc/arc_handler.py b/skythought/tools/tasks/arc/arc_handler.py index 2b00bfe..50b8c64 100644 --- a/skythought/tools/tasks/arc/arc_handler.py +++ b/skythought/tools/tasks/arc/arc_handler.py @@ -1,7 +1,7 @@ import re from typing import Any, Dict -from tasks.common import TaskHandler, TaskConfig +from tasks.base import TaskHandler, TaskConfig from util.math_parsing_util import extract_answer diff --git a/skythought/tools/tasks/common.py b/skythought/tools/tasks/base.py similarity index 100% rename from skythought/tools/tasks/common.py rename to skythought/tools/tasks/base.py diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py index 6055956..532305c 100644 --- a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py +++ b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py @@ -1,6 +1,6 @@ import random -from tasks.common import TaskHandler +from tasks.base import TaskHandler from util.math_parsing_util import get_multiple_choice_answer diff --git a/skythought/tools/tasks/gsm8k/gsm8k_handler.py b/skythought/tools/tasks/gsm8k/gsm8k_handler.py index 7820bb1..4a4780a 100644 --- a/skythought/tools/tasks/gsm8k/gsm8k_handler.py +++ b/skythought/tools/tasks/gsm8k/gsm8k_handler.py @@ -1,7 +1,7 @@ import re from typing import Any, Dict -from tasks.common import TaskHandler, TaskConfig +from tasks.base import TaskHandler, TaskConfig from util.math_parsing_util import extract_answer diff --git a/skythought/tools/tasks/livecodebench/livecodebench_handler.py b/skythought/tools/tasks/livecodebench/livecodebench_handler.py index 6957236..fcbd2a0 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench_handler.py +++ b/skythought/tools/tasks/livecodebench/livecodebench_handler.py @@ -1,7 +1,7 @@ import copy from typing import Dict, Optional -from tasks.common import TaskConfig, TaskHandler +from tasks.base import TaskConfig, TaskHandler from tasks.livecodebench.livecodebench_util import ( map_to_example, post_process_code, diff --git a/skythought/tools/tasks/math/math_handler.py b/skythought/tools/tasks/math/math_handler.py index c281deb..1cb1c9c 100644 --- a/skythought/tools/tasks/math/math_handler.py +++ b/skythought/tools/tasks/math/math_handler.py @@ -1,4 +1,4 @@ -from tasks.common import TaskHandler +from tasks.base import TaskHandler from util.math_parsing_util import extract_answer, math_equal, strip_answer_string diff --git a/skythought/tools/tasks/minervamath/minervamath_handler.py b/skythought/tools/tasks/minervamath/minervamath_handler.py index 5742268..033d6db 100644 
--- a/skythought/tools/tasks/minervamath/minervamath_handler.py +++ b/skythought/tools/tasks/minervamath/minervamath_handler.py @@ -1,4 +1,4 @@ -from tasks.common import TaskHandler +from tasks.base import TaskHandler from util.math_parsing_util import extract_answer, math_equal, strip_answer_string from tasks.math.math_handler import MathTaskHandler diff --git a/skythought/tools/tasks/mmlu/mmlu_handler.py b/skythought/tools/tasks/mmlu/mmlu_handler.py index bfa712f..67a39b7 100644 --- a/skythought/tools/tasks/mmlu/mmlu_handler.py +++ b/skythought/tools/tasks/mmlu/mmlu_handler.py @@ -1,8 +1,7 @@ from datasets import load_dataset from util.math_parsing_util import get_multiple_choice_answer, mmlu_pro_extract_answer - -from ..common import TaskHandler, TaskConfig +from tasks.base import TaskHandler, TaskConfig class MMLUTaskHandler(TaskHandler): @@ -63,11 +62,10 @@ def process_remaining_data(self, train_data, results): ] def load_and_filter_dataset( - self, start, end, split="test", source=None, filter_difficulty=None, args=None + self, start, end, split=None, source=None, filter_difficulty=None, args=None ): - dataset = load_dataset(self.dataset, "all") - train_data = dataset[split].to_pandas() - return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] + dataset = self.load_dataset(source=source, split=split).to_pandas() + return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] class MMLUProTaskHandler(MMLUTaskHandler): @@ -108,7 +106,7 @@ def get_multiple_choice_answers(self, problem): return f"Answer Choices: {options}" def load_and_filter_dataset( - self, start, end, split="test", source=None, filter_difficulty=None, args=None + self, start, end, split=None, source=None, filter_difficulty=None, args=None ): - dataset = self.load_dataset(source=source, split=split) + dataset = self.load_dataset(source=source, split=split).to_pandas() return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] diff --git a/skythought/tools/tasks/numina/numina_handler.py b/skythought/tools/tasks/numina/numina_handler.py index f2cdafa..3aa4f59 100644 --- a/skythought/tools/tasks/numina/numina_handler.py +++ b/skythought/tools/tasks/numina/numina_handler.py @@ -2,7 +2,7 @@ from datasets import load_dataset -from tasks.common import TaskConfig, TaskHandler +from tasks.base import TaskConfig, TaskHandler from util.common import TimeoutException, timeout from util.math_parsing_util import extract_answer, math_equal, strip_answer_string diff --git a/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py b/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py index e79d871..cd54780 100644 --- a/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py +++ b/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py @@ -1,4 +1,4 @@ -from tasks.common import TaskHandler +from tasks.base import TaskHandler from util.math_parsing_util import extract_answer, math_equal, strip_answer_string from tasks.math.math_handler import MathTaskHandler diff --git a/skythought/tools/tasks/olympiadbench/olympiadbench_math_en.yaml b/skythought/tools/tasks/olympiadbench/olympiadbench_math_en.yaml index 8f3d81b..b532ed5 100644 --- a/skythought/tools/tasks/olympiadbench/olympiadbench_math_en.yaml +++ b/skythought/tools/tasks/olympiadbench/olympiadbench_math_en.yaml @@ -1,8 +1,8 @@ handler: olympiadbench_math dataset_path: Hothan/OlympiadBench dataset_source: OE_TO_maths_en_COMP -dataset_split: test +dataset_split: train question_key: question answer_key: final_answer 
templating_parameters: - template: "Return your final response within \\boxed{{}}. {prompt}" + template: "Return your final response within \\boxed{{}}. {question}" diff --git a/skythought/tools/tasks/taco/taco.yaml b/skythought/tools/tasks/taco/taco.yaml index ee33d03..41d1cd3 100644 --- a/skythought/tools/tasks/taco/taco.yaml +++ b/skythought/tools/tasks/taco/taco.yaml @@ -6,6 +6,7 @@ dataset_kwargs: trust_remote_code: true question_key: question answer_key: null +difficulty: null templating_parameters: initial_template: "\nQUESTION:\n{prompt}" # Add starter code to initial template diff --git a/skythought/tools/tasks/taco/taco_handler.py b/skythought/tools/tasks/taco/taco_handler.py index 2cc222a..e052d26 100644 --- a/skythought/tools/tasks/taco/taco_handler.py +++ b/skythought/tools/tasks/taco/taco_handler.py @@ -8,7 +8,7 @@ from tasks.taco.taco_util import run_test as taco_run_test from util.common import has_code -from ..common import TaskConfig, TaskHandler +from ..base import TaskConfig, TaskHandler class TACOTaskConfig(TaskConfig): @@ -16,6 +16,7 @@ class TACOTaskConfig(TaskConfig): class TACOTaskHandler(TaskHandler): + task_config_cls = TACOTaskConfig def generate_prompt(self, prompt, starter_code=None, fn_name=None): _input = self.task_config.templating_parameters["initial_template"].format( @@ -113,7 +114,7 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=None, args=None + self, start, end, split=None, source=None, filter_difficulty=None, args=None ): dataset = self.load_dataset(source=source, split=split).to_pandas() if filter_difficulty or self.task_config.difficulty: From 77b8ba3793489a51e06ef5fa571076360645800e Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Mon, 27 Jan 2025 06:53:30 +0000 Subject: [PATCH 18/47] x Signed-off-by: SumanthRH --- skythought/tools/eval.py | 7 +++++-- skythought/tools/inference_and_check.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/skythought/tools/eval.py b/skythought/tools/eval.py index 71f69b0..b751aa9 100644 --- a/skythought/tools/eval.py +++ b/skythought/tools/eval.py @@ -1,13 +1,14 @@ import argparse import json -import subprocess import os +import subprocess from tasks.task_util import get_tasks module_dir = os.path.dirname(os.path.abspath(__file__)) TASK_NAMES_TO_YAML = get_tasks(os.path.join(module_dir, "tasks")) + def parse_arguments(): parser = argparse.ArgumentParser( description="Process model path, prompt format, and evals to run." 
@@ -85,7 +86,9 @@ def main(): # Run the Python command for each eval and collect logs for eval_name in evals: eval_name = eval_name.lower() - assert eval_name in TASK_NAMES_TO_YAML.keys(), f"Task {eval_name} not found, should be one of {TASK_NAMES_TO_YAML.keys()}" + assert ( + eval_name in TASK_NAMES_TO_YAML.keys() + ), f"Task {eval_name} not found, should be one of {TASK_NAMES_TO_YAML.keys()}" command = [ "python", script_path, diff --git a/skythought/tools/inference_and_check.py b/skythought/tools/inference_and_check.py index 2fa4a16..456922c 100644 --- a/skythought/tools/inference_and_check.py +++ b/skythought/tools/inference_and_check.py @@ -17,6 +17,7 @@ module_dir = os.path.dirname(os.path.abspath(__file__)) TASK_NAMES_TO_YAML = get_tasks(os.path.join(module_dir, "tasks")) + class NumpyEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.ndarray): @@ -137,7 +138,7 @@ def perform_inference_and_check( prompt = conversations[idx][1]["content"] results[problem_key]["prompt"] = prompt results[problem_key]["input_conversation"] = conversations[idx] - + results[problem_key]["responses"][str(temp)] = response_entry if args.model.startswith("openai"): From 52a4ce7c6ff95b56188414d4228cb017069fa339 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Mon, 27 Jan 2025 07:28:42 +0000 Subject: [PATCH 19/47] fixing pre-commit Signed-off-by: SumanthRH --- skythought/tools/.githooks/pre-commit | 11 +++-- skythought/tools/.pre-commit-config.yaml | 4 +- skythought/tools/format.sh | 7 ++- skythought/tools/inference_and_check.py | 62 +++++++++++++++--------- skythought/tools/pyproject.toml | 6 ++- skythought/tools/upload_hub.py | 1 + 6 files changed, 59 insertions(+), 32 deletions(-) diff --git a/skythought/tools/.githooks/pre-commit b/skythought/tools/.githooks/pre-commit index 927ae65..86bc517 100755 --- a/skythought/tools/.githooks/pre-commit +++ b/skythought/tools/.githooks/pre-commit @@ -1,10 +1,15 @@ set -e +git_root=$(git rev-parse --show-toplevel) +tools_relative=skythought/tools # Get tools directory path relative to git root -TOOLS_DIR=$(git rev-parse --show-toplevel)/skythought/tools +TOOLS_DIR=$git_root/$tools_relative # Only run pre-commit if changes are in tools/ # Run pre-commit from tools/ directory to use linting rules in this directory -if git diff --cached --name-only | grep "^skythought/tools/"; then +tools_files=$(git diff --cached --name-only --relative="$tools_relative" -- ./) + +if [ -n $tools_files ]; then cd $TOOLS_DIR; - pre-commit run --files $(git diff --cached --name-only | grep "^skythought/tools/") --config .pre-commit-config.yaml + # only get the diffs in tools/ and run pre-commit + pre-commit run --files $tools_files --config .pre-commit-config.yaml fi diff --git a/skythought/tools/.pre-commit-config.yaml b/skythought/tools/.pre-commit-config.yaml index 003a141..30d726a 100644 --- a/skythought/tools/.pre-commit-config.yaml +++ b/skythought/tools/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.9.3 hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] # Black needs to be ran after ruff with --fix - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 24.10.0 hooks: - id: black \ No newline at end of file diff --git a/skythought/tools/format.sh b/skythought/tools/format.sh index 2296f90..a5fe91e 100644 --- a/skythought/tools/format.sh +++ b/skythought/tools/format.sh @@ -4,6 +4,11 @@ set -e # Get tools directory path relative to git root TOOLS_DIR=$(git rev-parse 
--show-toplevel)/skythought/tools +if [ ! -d "$TOOLS_DIR" ]; then + echo "Error: Tools directory not found at $TOOLS_DIR" + exit 1 +fi + if command -v uv >/dev/null 2>&1; then uv pip install -q pre-commit else @@ -17,4 +22,4 @@ chmod +x $HOOK_SCRIPT git config --local core.hooksPath "$TOOLS_DIR/.githooks" # pre-commit run --all-files always runs from the root directory. we run this only on tools/ for now. cd $TOOLS_DIR; -pre-commit run --files $TOOLS_DIR/* --config $TOOLS_DIR/.pre-commit-config.yaml \ No newline at end of file +pre-commit run --files ./* --config .pre-commit-config.yaml \ No newline at end of file diff --git a/skythought/tools/inference_and_check.py b/skythought/tools/inference_and_check.py index 456922c..98fb695 100644 --- a/skythought/tools/inference_and_check.py +++ b/skythought/tools/inference_and_check.py @@ -180,14 +180,14 @@ def perform_inference_and_check( token_dict = { "completion_tokens": sum(completion_tokens), "prompt_tokens": sum(prompt_tokens), - "avg_completion_tokens": round( - sum(completion_tokens) / len(completion_tokens), 3 - ) - if completion_tokens - else 0, - "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) - if prompt_tokens - else 0, + "avg_completion_tokens": ( + round(sum(completion_tokens) / len(completion_tokens), 3) + if completion_tokens + else 0 + ), + "avg_prompt_tokens": ( + round(sum(prompt_tokens) / len(prompt_tokens), 3) if prompt_tokens else 0 + ), } # Save the token usage dictionary to the result file @@ -231,9 +231,11 @@ def perform_check(handler: TaskHandler, temperatures, result_file, args): ( item, temp, - response_entry["processed_content"] - if processed - else response_entry["content"], + ( + response_entry["processed_content"] + if processed + else response_entry["content"] + ), sample_id, ) ) @@ -341,9 +343,11 @@ def perform_inference_and_save( completion_token = 0 for sample_idx in range(args.n): response_entry = { - "content": response.choices[0].message.content.strip() - if args.model.startswith("openai") - else response.outputs[sample_idx].text.strip(), + "content": ( + response.choices[0].message.content.strip() + if args.model.startswith("openai") + else response.outputs[sample_idx].text.strip() + ), "correctness": None, "reason": None, } @@ -397,14 +401,14 @@ def perform_inference_and_save( token_dict = { "completion_tokens": sum(completion_tokens), "prompt_tokens": sum(prompt_tokens), - "avg_completion_tokens": round( - sum(completion_tokens) / len(completion_tokens), 3 - ) - if completion_tokens - else 0, - "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) - if prompt_tokens - else 0, + "avg_completion_tokens": ( + round(sum(completion_tokens) / len(completion_tokens), 3) + if completion_tokens + else 0 + ), + "avg_prompt_tokens": ( + round(sum(prompt_tokens) / len(prompt_tokens), 3) if prompt_tokens else 0 + ), } # Save the token usage dictionary to the result file @@ -527,7 +531,10 @@ def main(): args.math_difficulty_lower_bound is not None or args.math_difficulty_upper_bound is not None ): - converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json" + converted_file = ( + f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}" + + f"_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json" + ) else: converted_file = 
f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}.json" if os.path.exists(converted_file): @@ -552,8 +559,15 @@ def main(): else LLM(model=args.model, tensor_parallel_size=args.tp) ) system_prompt = SYSTEM_PROMPT[args.model] + perform_inference_and_check( - handler, temperatures, max_tokens, result_file, llm, system_prompt, args + handler, + temperatures, + max_tokens, + result_file, + llm, + system_prompt, + args, ) diff --git a/skythought/tools/pyproject.toml b/skythought/tools/pyproject.toml index 88570fd..58eac71 100644 --- a/skythought/tools/pyproject.toml +++ b/skythought/tools/pyproject.toml @@ -1,3 +1,5 @@ [tool.ruff] -lint.select = ["E", "F", "I", "ASYNC", "B"] -line-length = 300 \ No newline at end of file +line-length = 160 + +[tool.ruff.lint] +extend-select = ["E", "F", "I", "ASYNC", "B"] \ No newline at end of file diff --git a/skythought/tools/upload_hub.py b/skythought/tools/upload_hub.py index 8b4d052..3020648 100644 --- a/skythought/tools/upload_hub.py +++ b/skythought/tools/upload_hub.py @@ -5,6 +5,7 @@ Usage: python upload_hub.py --model-path ~/model_weights/Sky-T1 --hub-repo-id NovaSky-AI/Sky-T1 --private """ + import argparse import tempfile From d21d68ec9a090ac05d32f718f95427be4192137c Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Mon, 27 Jan 2025 20:29:51 +0000 Subject: [PATCH 20/47] x Signed-off-by: SumanthRH --- skythought/tools/.githooks/pre-commit | 10 --- skythought/tools/.pre-commit-config.yaml | 12 ---- skythought/tools/format.sh | 20 ------ skythought/tools/inference_and_check.py | 67 +++++++++++-------- skythought/tools/pyproject.toml | 3 - skythought/tools/tasks/apps/apps.yaml | 4 +- skythought/tools/tasks/apps/apps_handler.py | 23 ++----- skythought/tools/tasks/base.py | 18 +++-- .../tasks/livecodebench/livecodebench.yaml | 3 +- .../livecodebench/livecodebench_handler.py | 23 +++---- skythought/tools/tasks/numina/numina.yaml | 2 + .../tools/tasks/numina/numina_handler.py | 10 +-- skythought/tools/tasks/taco/taco.yaml | 3 +- skythought/tools/tasks/taco/taco_handler.py | 17 ++--- skythought/tools/upload_hub.py | 1 + skythought/train/LLaMA-Factory/setup.py | 8 +++ 16 files changed, 92 insertions(+), 132 deletions(-) delete mode 100755 skythought/tools/.githooks/pre-commit delete mode 100644 skythought/tools/.pre-commit-config.yaml delete mode 100644 skythought/tools/format.sh delete mode 100644 skythought/tools/pyproject.toml diff --git a/skythought/tools/.githooks/pre-commit b/skythought/tools/.githooks/pre-commit deleted file mode 100755 index 927ae65..0000000 --- a/skythought/tools/.githooks/pre-commit +++ /dev/null @@ -1,10 +0,0 @@ -set -e - -# Get tools directory path relative to git root -TOOLS_DIR=$(git rev-parse --show-toplevel)/skythought/tools -# Only run pre-commit if changes are in tools/ -# Run pre-commit from tools/ directory to use linting rules in this directory -if git diff --cached --name-only | grep "^skythought/tools/"; then - cd $TOOLS_DIR; - pre-commit run --files $(git diff --cached --name-only | grep "^skythought/tools/") --config .pre-commit-config.yaml -fi diff --git a/skythought/tools/.pre-commit-config.yaml b/skythought/tools/.pre-commit-config.yaml deleted file mode 100644 index 003a141..0000000 --- a/skythought/tools/.pre-commit-config.yaml +++ /dev/null @@ -1,12 +0,0 @@ -repos: - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 - hooks: - - id: ruff - args: [ --fix, --exit-non-zero-on-fix ] - - # Black needs to be ran after ruff with --fix 
- - repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black \ No newline at end of file diff --git a/skythought/tools/format.sh b/skythought/tools/format.sh deleted file mode 100644 index 2296f90..0000000 --- a/skythought/tools/format.sh +++ /dev/null @@ -1,20 +0,0 @@ - -set -e - -# Get tools directory path relative to git root -TOOLS_DIR=$(git rev-parse --show-toplevel)/skythought/tools - -if command -v uv >/dev/null 2>&1; then - uv pip install -q pre-commit -else - pip install -q pre-commit -fi - -# Hook file should be executable -HOOK_SCRIPT=$TOOLS_DIR/.githooks/pre-commit -chmod +x $HOOK_SCRIPT - -git config --local core.hooksPath "$TOOLS_DIR/.githooks" -# pre-commit run --all-files always runs from the root directory. we run this only on tools/ for now. -cd $TOOLS_DIR; -pre-commit run --files $TOOLS_DIR/* --config $TOOLS_DIR/.pre-commit-config.yaml \ No newline at end of file diff --git a/skythought/tools/inference_and_check.py b/skythought/tools/inference_and_check.py index 456922c..3a3595f 100644 --- a/skythought/tools/inference_and_check.py +++ b/skythought/tools/inference_and_check.py @@ -7,12 +7,11 @@ import numpy as np from openai import OpenAI -from tqdm import tqdm -from vllm import LLM, SamplingParams - from tasks import TASK_HANDLER_MAP, NUMINATaskHandler, TaskHandler from tasks.task_util import get_tasks +from tqdm import tqdm from util.model_utils import MODEL_TO_NAME, SYSTEM_PROMPT +from vllm import LLM, SamplingParams module_dir = os.path.dirname(os.path.abspath(__file__)) TASK_NAMES_TO_YAML = get_tasks(os.path.join(module_dir, "tasks")) @@ -180,14 +179,14 @@ def perform_inference_and_check( token_dict = { "completion_tokens": sum(completion_tokens), "prompt_tokens": sum(prompt_tokens), - "avg_completion_tokens": round( - sum(completion_tokens) / len(completion_tokens), 3 - ) - if completion_tokens - else 0, - "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) - if prompt_tokens - else 0, + "avg_completion_tokens": ( + round(sum(completion_tokens) / len(completion_tokens), 3) + if completion_tokens + else 0 + ), + "avg_prompt_tokens": ( + round(sum(prompt_tokens) / len(prompt_tokens), 3) if prompt_tokens else 0 + ), } # Save the token usage dictionary to the result file @@ -231,9 +230,11 @@ def perform_check(handler: TaskHandler, temperatures, result_file, args): ( item, temp, - response_entry["processed_content"] - if processed - else response_entry["content"], + ( + response_entry["processed_content"] + if processed + else response_entry["content"] + ), sample_id, ) ) @@ -341,9 +342,11 @@ def perform_inference_and_save( completion_token = 0 for sample_idx in range(args.n): response_entry = { - "content": response.choices[0].message.content.strip() - if args.model.startswith("openai") - else response.outputs[sample_idx].text.strip(), + "content": ( + response.choices[0].message.content.strip() + if args.model.startswith("openai") + else response.outputs[sample_idx].text.strip() + ), "correctness": None, "reason": None, } @@ -397,14 +400,14 @@ def perform_inference_and_save( token_dict = { "completion_tokens": sum(completion_tokens), "prompt_tokens": sum(prompt_tokens), - "avg_completion_tokens": round( - sum(completion_tokens) / len(completion_tokens), 3 - ) - if completion_tokens - else 0, - "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) - if prompt_tokens - else 0, + "avg_completion_tokens": ( + round(sum(completion_tokens) / len(completion_tokens), 3) + if completion_tokens + else 0 + ), + 
"avg_prompt_tokens": ( + round(sum(prompt_tokens) / len(prompt_tokens), 3) if prompt_tokens else 0 + ), } # Save the token usage dictionary to the result file @@ -527,7 +530,10 @@ def main(): args.math_difficulty_lower_bound is not None or args.math_difficulty_upper_bound is not None ): - converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json" + converted_file = ( + f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}" + + f"_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json" + ) else: converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}.json" if os.path.exists(converted_file): @@ -552,8 +558,15 @@ def main(): else LLM(model=args.model, tensor_parallel_size=args.tp) ) system_prompt = SYSTEM_PROMPT[args.model] + perform_inference_and_check( - handler, temperatures, max_tokens, result_file, llm, system_prompt, args + handler, + temperatures, + max_tokens, + result_file, + llm, + system_prompt, + args, ) diff --git a/skythought/tools/pyproject.toml b/skythought/tools/pyproject.toml deleted file mode 100644 index 88570fd..0000000 --- a/skythought/tools/pyproject.toml +++ /dev/null @@ -1,3 +0,0 @@ -[tool.ruff] -lint.select = ["E", "F", "I", "ASYNC", "B"] -line-length = 300 \ No newline at end of file diff --git a/skythought/tools/tasks/apps/apps.yaml b/skythought/tools/tasks/apps/apps.yaml index 4fca2aa..fa5f38c 100644 --- a/skythought/tools/tasks/apps/apps.yaml +++ b/skythought/tools/tasks/apps/apps.yaml @@ -10,4 +10,6 @@ templating_parameters: with_fn_name_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}" without_fn_name_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. 
{prompt}" # Add starter code on top of the initial template - with_starter_code_template: "{input}\n{starter_code}" \ No newline at end of file + with_starter_code_template: "{input}\n{starter_code}" +# preprocess_config; +# difficulty: easy # optional filter config \ No newline at end of file diff --git a/skythought/tools/tasks/apps/apps_handler.py b/skythought/tools/tasks/apps/apps_handler.py index 1e9e3f6..14e332d 100644 --- a/skythought/tools/tasks/apps/apps_handler.py +++ b/skythought/tools/tasks/apps/apps_handler.py @@ -2,36 +2,25 @@ import json import multiprocessing from multiprocessing import Manager -from typing import Optional import numpy as np - from tasks.apps.apps_util import run_test as apps_run_test from util.common import has_code -from ..base import TaskHandler, TaskConfig - +from ..base import TaskHandler -class APPSTaskConfig(TaskConfig): - # by default, no filter on difficulty - difficulty: Optional[str] = None class APPSTaskHandler(TaskHandler): - task_config_cls = APPSTaskConfig def generate_prompt(self, test_case, prompt, starter_code=None): if not test_case.get("fn_name"): _input = self.task_config.templating_parameters[ "with_fn_name_template" - ].format( - prompt=prompt - ) + ].format(prompt=prompt) else: _input = self.task_config.templating_parameters[ "without_fn_name_template" - ].format( - prompt=prompt - ) + ].format(prompt=prompt) if starter_code is not None: _input = self.task_config.templating_parameters[ "with_starter_code_template" @@ -112,9 +101,11 @@ def load_and_filter_dataset( self, start, end, split=None, source=None, filter_difficulty=None, args=None ): train_data = self.load_dataset(source=source, split=split).to_pandas() - if filter_difficulty or self.task_config.difficulty: + if filter_difficulty or self.task_config.preprocess_config.difficulty: difficulty = ( - self.task_config.difficulty if not filter_difficulty else filter_difficulty + self.task_config.preprocess_config.difficulty + if not filter_difficulty + else filter_difficulty ) train_data = train_data.filter(lambda x: x["difficulty"] == difficulty) diff --git a/skythought/tools/tasks/base.py b/skythought/tools/tasks/base.py index ebac3b3..b91e8e4 100644 --- a/skythought/tools/tasks/base.py +++ b/skythought/tools/tasks/base.py @@ -8,6 +8,10 @@ from pydantic import BaseModel, Field +class PreprocessConfig(BaseModel): + difficulty: str + + class TaskConfig(BaseModel): handler: str dataset_path: str @@ -16,17 +20,13 @@ class TaskConfig(BaseModel): dataset_kwargs: Dict[str, Any] = Field(default_factory=dict) question_key: str # Optional answer key for datasets with a single correct answer - answer_key: Optional[str] = None + answer_key: Optional[str] = None templating_parameters: Dict[str, str] = Field(default_factory=dict) # Optional, unused for now fewshot_config: List[Dict[str, Any]] = Field(default_factory=list) num_fewshot: int = 0 - @property - def handler_cls(self): - from tasks import TASK_HANDLER_MAP - - return TASK_HANDLER_MAP[self.handler] + preprocess_config: Optional[PreprocessConfig] = None @classmethod def from_yaml(cls, yaml_file_path) -> "TaskConfig": @@ -35,16 +35,14 @@ def from_yaml(cls, yaml_file_path) -> "TaskConfig": return cls(**config_dict) - class TaskHandler: - task_config_cls = TaskConfig def __init__(self, task_config: TaskConfig): self.task_config = task_config - + @classmethod def from_config_path(cls, config_path: str) -> "TaskHandler": - task_config = cls.task_config_cls.from_yaml(config_path) + task_config = TaskConfig.from_yaml(config_path) return 
cls(task_config) @property diff --git a/skythought/tools/tasks/livecodebench/livecodebench.yaml b/skythought/tools/tasks/livecodebench/livecodebench.yaml index a1fd7fd..a8347fd 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench.yaml +++ b/skythought/tools/tasks/livecodebench/livecodebench.yaml @@ -10,4 +10,5 @@ answer_key: null templating_parameters: stdin_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}" non_stdin_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. {prompt}" -difficulty: null # use all by default +preprocess_config: null +# difficulty: easy # use all by default \ No newline at end of file diff --git a/skythought/tools/tasks/livecodebench/livecodebench_handler.py b/skythought/tools/tasks/livecodebench/livecodebench_handler.py index fcbd2a0..7fb13f1 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench_handler.py +++ b/skythought/tools/tasks/livecodebench/livecodebench_handler.py @@ -1,7 +1,8 @@ import copy -from typing import Dict, Optional +from typing import Dict -from tasks.base import TaskConfig, TaskHandler +from datasets import Dataset as HFDataset +from tasks.base import TaskHandler from tasks.livecodebench.livecodebench_util import ( map_to_example, post_process_code, @@ -9,15 +10,9 @@ unsafe_lcb_runTests, ) from util.common import has_code -from datasets import Dataset as HFDataset - - -class LiveCodeBenchTaskConfig(TaskConfig): - difficulty: Optional[str] = None # use all by default class LiveCodeBenchTaskHandler(TaskHandler): - task_config_cls = LiveCodeBenchTaskConfig def generate_prompt(self, problem): if problem["is_stdin"]: @@ -106,13 +101,17 @@ def load_and_filter_dataset( ): dataset: HFDataset = self.load_dataset(source=source, split=split) # Filter by CLI or config - if filter_difficulty or self.task_config.difficulty: - difficulty = source if filter_difficulty else self.task_config.difficulty + if filter_difficulty or self.task_config.preprocess_config.difficulty: + difficulty = ( + source + if filter_difficulty + else self.task_config.preprocess_config.difficulty + ) dataset = dataset.filter( lambda example: example["difficulty"] == difficulty ) - # We use a lower writer_batch_size to avoid pyarrow issues. JSON entries with LiveCodeBench are large. - # See: https://github.com/NovaSky-AI/SkyThought/pull/45 for details. + # We use a lower writer_batch_size to avoid pyarrow issues. JSON entries with LiveCodeBench are large. + # See: https://github.com/NovaSky-AI/SkyThought/pull/45 for details. dataset = dataset.map( lambda example: { "private_test_cases": translate_private_test_cases( diff --git a/skythought/tools/tasks/numina/numina.yaml b/skythought/tools/tasks/numina/numina.yaml index e1d8e3d..1a0f72d 100644 --- a/skythought/tools/tasks/numina/numina.yaml +++ b/skythought/tools/tasks/numina/numina.yaml @@ -6,3 +6,5 @@ question_key: problem answer_key: solution templating_parameters: template: "Return your final response within \\boxed{{}}. 
{prompt}" +preprocess_config: + difficulty: null \ No newline at end of file diff --git a/skythought/tools/tasks/numina/numina_handler.py b/skythought/tools/tasks/numina/numina_handler.py index 3aa4f59..b00dcb4 100644 --- a/skythought/tools/tasks/numina/numina_handler.py +++ b/skythought/tools/tasks/numina/numina_handler.py @@ -1,18 +1,10 @@ -from typing import Optional - from datasets import load_dataset - -from tasks.base import TaskConfig, TaskHandler +from tasks.base import TaskHandler from util.common import TimeoutException, timeout from util.math_parsing_util import extract_answer, math_equal, strip_answer_string -class NUMINATaskConfig(TaskConfig): - difficulty: Optional[str] = None # use all by default - - class NUMINATaskHandler(TaskHandler): - task_config_cls = NUMINATaskConfig def generate_prompt(self, prompt): return self.task_config.templating_parameters["template"].format(prompt=prompt) diff --git a/skythought/tools/tasks/taco/taco.yaml b/skythought/tools/tasks/taco/taco.yaml index 41d1cd3..c0521cb 100644 --- a/skythought/tools/tasks/taco/taco.yaml +++ b/skythought/tools/tasks/taco/taco.yaml @@ -6,7 +6,6 @@ dataset_kwargs: trust_remote_code: true question_key: question answer_key: null -difficulty: null templating_parameters: initial_template: "\nQUESTION:\n{prompt}" # Add starter code to initial template @@ -15,4 +14,6 @@ templating_parameters: stdin_template: "{input}\nUse Standard Input format\nANSWER:\n" # call template is used when there is starter code or fn_name call_template: "{input}\nUse Call-Based format\nANSWER:\n" +preprocess_config: + difficulty: null diff --git a/skythought/tools/tasks/taco/taco_handler.py b/skythought/tools/tasks/taco/taco_handler.py index e052d26..256ffed 100644 --- a/skythought/tools/tasks/taco/taco_handler.py +++ b/skythought/tools/tasks/taco/taco_handler.py @@ -1,22 +1,15 @@ import json import multiprocessing from multiprocessing import Manager -from typing import Optional import numpy as np - from tasks.taco.taco_util import run_test as taco_run_test from util.common import has_code -from ..base import TaskConfig, TaskHandler - - -class TACOTaskConfig(TaskConfig): - difficulty: Optional[str] = None # use all by default +from ..base import TaskHandler class TACOTaskHandler(TaskHandler): - task_config_cls = TACOTaskConfig def generate_prompt(self, prompt, starter_code=None, fn_name=None): _input = self.task_config.templating_parameters["initial_template"].format( @@ -117,8 +110,12 @@ def load_and_filter_dataset( self, start, end, split=None, source=None, filter_difficulty=None, args=None ): dataset = self.load_dataset(source=source, split=split).to_pandas() - if filter_difficulty or self.task_config.difficulty: - difficulty = source if filter_difficulty else self.task_config.difficulty + if filter_difficulty or self.task_config.preprocess_config.difficulty: + difficulty = ( + source + if filter_difficulty + else self.task_config.preprocess_config.difficulty + ) dataset = dataset.filter( lambda example: example["difficulty"] == difficulty ) diff --git a/skythought/tools/upload_hub.py b/skythought/tools/upload_hub.py index 8b4d052..3020648 100644 --- a/skythought/tools/upload_hub.py +++ b/skythought/tools/upload_hub.py @@ -5,6 +5,7 @@ Usage: python upload_hub.py --model-path ~/model_weights/Sky-T1 --hub-repo-id NovaSky-AI/Sky-T1 --private """ + import argparse import tempfile diff --git a/skythought/train/LLaMA-Factory/setup.py b/skythought/train/LLaMA-Factory/setup.py index 862e9b9..3329f07 100644 --- 
a/skythought/train/LLaMA-Factory/setup.py +++ b/skythought/train/LLaMA-Factory/setup.py @@ -102,3 +102,11 @@ def main(): if __name__ == "__main__": main() + + + + + + + + From 2069df7a074bdbc83e57e35971973bd08494e05f Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Mon, 27 Jan 2025 21:38:54 +0000 Subject: [PATCH 21/47] move some stuff; init tests; init package skyevals Signed-off-by: SumanthRH --- .githooks/pre-commit | 14 ++++++ .pre-commit-config.yaml | 14 ++++++ format.sh | 21 ++++++++ pyproject.toml | 5 ++ setup.py | 10 ++++ skythought/__init__.py | 0 skythought/tools/combine_data.py | 2 +- skythought/tools/convert_format.py | 2 +- skythought/tools/convert_to_data.py | 2 +- skythought/tools/inference_and_check.py | 3 +- skythought/tools/label_math_difficulty.py | 2 +- skythought/tools/response_rewrite.py | 4 +- skythought/tools/tasks/aime/aime_handler.py | 9 ++-- skythought/tools/tasks/amc23/amc23_handler.py | 2 +- skythought/tools/tasks/apps/apps_handler.py | 4 +- skythought/tools/tasks/arc/arc_handler.py | 8 +-- .../gpqa_diamond/gpqa_diamond_handler.py | 4 +- skythought/tools/tasks/gsm8k/gsm8k_handler.py | 4 +- .../livecodebench/livecodebench_handler.py | 7 +-- skythought/tools/tasks/math/math_handler.py | 6 +-- .../tasks/minervamath/minervamath_handler.py | 8 +-- skythought/tools/tasks/mmlu/mmlu_handler.py | 18 ++++--- .../tools/tasks/numina/numina_handler.py | 7 +-- .../olympiadbench/olympiadbench_handler.py | 11 ++-- skythought/tools/tasks/taco/taco_handler.py | 4 +- skythought/tools/util/__init__.py | 0 tests/tools/preprocessing.py | 50 +++++++++++++++++++ 27 files changed, 171 insertions(+), 50 deletions(-) create mode 100755 .githooks/pre-commit create mode 100644 .pre-commit-config.yaml create mode 100644 format.sh create mode 100644 pyproject.toml create mode 100644 setup.py create mode 100644 skythought/__init__.py create mode 100644 skythought/tools/util/__init__.py create mode 100644 tests/tools/preprocessing.py diff --git a/.githooks/pre-commit b/.githooks/pre-commit new file mode 100755 index 0000000..cd2a157 --- /dev/null +++ b/.githooks/pre-commit @@ -0,0 +1,14 @@ +set -e + +GIT_ROOT=$(git rev-parse --show-toplevel) + +tools_relative=skythought/tools +# Only run pre-commit if changes are in tools/ +# Run pre-commit from tools/ directory to use linting rules in this directory +tools_files=$(git diff --cached --name-only --relative="./" -- $tools_relative) + +if [ -n $tools_files ]; then + cd $TOOLS_DIR; + # only get the diffs in tools/ and run pre-commit + pre-commit run --files $tools_files --config .pre-commit-config.yaml +fi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f978811 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,14 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.3 + hooks: + - id: ruff + args: [ --fix, --exit-non-zero-on-fix ] + exclude: ^skythought/train + + # Black needs to be ran after ruff with --fix + - repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black + exclude: ^skythought/train diff --git a/format.sh b/format.sh new file mode 100644 index 0000000..bbc3b5f --- /dev/null +++ b/format.sh @@ -0,0 +1,21 @@ + +set -e + +# Get tools directory path relative to git root +GIT_ROOT=$(git rev-parse --show-toplevel) +TOOLS_RELATIVE=skythought/tools +TOOLS_DIR=$GIT_ROOT/$TOOLS_RELATIVE + +if command -v uv >/dev/null 2>&1; then + uv pip install -q pre-commit +else + pip install -q pre-commit +fi + +# Hook file should be executable 
+HOOK_SCRIPT=$GIT_ROOT/.githooks/pre-commit +chmod +x $HOOK_SCRIPT + +# git config --local core.hooksPath ".githooks" +# pre-commit run --all-files always runs from the root directory. we run this only on tools/ for now. +pre-commit run --files $GIT_ROOT/skythought/tools --config .pre-commit-config.yaml \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..58eac71 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,5 @@ +[tool.ruff] +line-length = 160 + +[tool.ruff.lint] +extend-select = ["E", "F", "I", "ASYNC", "B"] \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..5d1c372 --- /dev/null +++ b/setup.py @@ -0,0 +1,10 @@ +# setup module skyevals in tools directory +import setuptools + +setuptools.setup( + name="skyevals", + version="0.0.1", + package_dir={"skyevals": "skythought/tools"}, # map skyevals to skythought/tools + packages=["skyevals"] + + [f"skyevals.{pkg}" for pkg in setuptools.find_packages(where="skythought/tools")], +) diff --git a/skythought/__init__.py b/skythought/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/skythought/tools/combine_data.py b/skythought/tools/combine_data.py index 2ab86e9..8106501 100644 --- a/skythought/tools/combine_data.py +++ b/skythought/tools/combine_data.py @@ -1,7 +1,7 @@ import json import random -from util.prompts import system_prompt +from .util.prompts import system_prompt still2_jsonl_file = "../../data/public_long_form_thought_data_5k.jsonl" code_json_file = "../../data/converted_apps_long_form_thought_data_5k.json" diff --git a/skythought/tools/convert_format.py b/skythought/tools/convert_format.py index ead9d3a..17fe9f3 100644 --- a/skythought/tools/convert_format.py +++ b/skythought/tools/convert_format.py @@ -8,7 +8,7 @@ import openai from tqdm import tqdm -from util.prompts import convert_prompt, convert_prompt_example +from .util.prompts import convert_prompt, convert_prompt_example global args diff --git a/skythought/tools/convert_to_data.py b/skythought/tools/convert_to_data.py index 152f9a8..890c4f4 100644 --- a/skythought/tools/convert_to_data.py +++ b/skythought/tools/convert_to_data.py @@ -2,7 +2,7 @@ import json import os -from util.prompts import system_prompt +from .util.prompts import system_prompt def main(): diff --git a/skythought/tools/inference_and_check.py b/skythought/tools/inference_and_check.py index 3a3595f..7a7d5f1 100644 --- a/skythought/tools/inference_and_check.py +++ b/skythought/tools/inference_and_check.py @@ -10,9 +10,10 @@ from tasks import TASK_HANDLER_MAP, NUMINATaskHandler, TaskHandler from tasks.task_util import get_tasks from tqdm import tqdm -from util.model_utils import MODEL_TO_NAME, SYSTEM_PROMPT from vllm import LLM, SamplingParams +from .util.model_utils import MODEL_TO_NAME, SYSTEM_PROMPT + module_dir = os.path.dirname(os.path.abspath(__file__)) TASK_NAMES_TO_YAML = get_tasks(os.path.join(module_dir, "tasks")) diff --git a/skythought/tools/label_math_difficulty.py b/skythought/tools/label_math_difficulty.py index 3ab5804..d810a83 100644 --- a/skythought/tools/label_math_difficulty.py +++ b/skythought/tools/label_math_difficulty.py @@ -11,7 +11,7 @@ from datasets import load_dataset from tqdm import tqdm -from util.prompts import aops_criteria, grading_prompt +from .util.prompts import aops_criteria, grading_prompt # Function to set the OpenAI API key diff --git a/skythought/tools/response_rewrite.py b/skythought/tools/response_rewrite.py index eaa307a..04e888f 100644 --- 
a/skythought/tools/response_rewrite.py +++ b/skythought/tools/response_rewrite.py @@ -3,11 +3,11 @@ import os import random -from skythought.tools.util.math_parsing_util import strip_answer_string from tqdm import tqdm from vllm import LLM, SamplingParams -from util.model_utils import ( +from .util.math_parsing_util import strip_answer_string +from .util.model_utils import ( SUBPROBLEM_SPLIT_PROMPT, SUBSOLUTION_EXTRACTION_PROMPT, SYSTEM_PROMPT, diff --git a/skythought/tools/tasks/aime/aime_handler.py b/skythought/tools/tasks/aime/aime_handler.py index 17bf62a..afe7ea6 100644 --- a/skythought/tools/tasks/aime/aime_handler.py +++ b/skythought/tools/tasks/aime/aime_handler.py @@ -1,6 +1,7 @@ -from typing import Dict -from tasks.math.math_handler import MathTaskHandler -from util.model_utils import MODEL_TO_NAME +from typing import Dict + +from ...util.model_utils import MODEL_TO_NAME +from ..math.math_handler import MathTaskHandler class AIMETaskHandler(MathTaskHandler): @@ -11,7 +12,7 @@ def generate_prompt(self, problem: Dict, model): ) else: return self.task_config.templating_parameters["regular_template"].format( - prompt=problem["problem"] + prompt=problem["problem"] ) def make_conversations(self, data, system_prompt, model=None): diff --git a/skythought/tools/tasks/amc23/amc23_handler.py b/skythought/tools/tasks/amc23/amc23_handler.py index ff598fb..5213559 100644 --- a/skythought/tools/tasks/amc23/amc23_handler.py +++ b/skythought/tools/tasks/amc23/amc23_handler.py @@ -1,4 +1,4 @@ -from tasks.math.math_handler import MathTaskHandler +from ..math.math_handler import MathTaskHandler class AMC23TaskHandler(MathTaskHandler): diff --git a/skythought/tools/tasks/apps/apps_handler.py b/skythought/tools/tasks/apps/apps_handler.py index 14e332d..ad56e01 100644 --- a/skythought/tools/tasks/apps/apps_handler.py +++ b/skythought/tools/tasks/apps/apps_handler.py @@ -4,9 +4,9 @@ from multiprocessing import Manager import numpy as np -from tasks.apps.apps_util import run_test as apps_run_test -from util.common import has_code +from ...util.common import has_code +from ..apps.apps_util import run_test as apps_run_test from ..base import TaskHandler diff --git a/skythought/tools/tasks/arc/arc_handler.py b/skythought/tools/tasks/arc/arc_handler.py index 50b8c64..34c88f4 100644 --- a/skythought/tools/tasks/arc/arc_handler.py +++ b/skythought/tools/tasks/arc/arc_handler.py @@ -1,8 +1,8 @@ import re from typing import Any, Dict -from tasks.base import TaskHandler, TaskConfig -from util.math_parsing_util import extract_answer +from ...util.math_parsing_util import extract_answer +from ..base import TaskConfig, TaskHandler class ARCChallengeTaskHandler(TaskHandler): @@ -30,7 +30,9 @@ def generate_prompt(self, problem): def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool: gt_answer = problem[self.task_config.answer_key] if gt_answer not in self.canonical_options: - gt_answer = self.canonical_options[int(problem[self.task_config.answer_key]) - 1] + gt_answer = self.canonical_options[ + int(problem[self.task_config.answer_key]) - 1 + ] model_answer = self.get_answer(generation) return model_answer == gt_answer diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py index 532305c..bff304b 100644 --- a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py +++ b/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py @@ -1,7 +1,7 @@ import random -from tasks.base import TaskHandler -from 
util.math_parsing_util import get_multiple_choice_answer +from ...util.math_parsing_util import get_multiple_choice_answer +from ..base import TaskHandler class GPQADiamondTaskHandler(TaskHandler): diff --git a/skythought/tools/tasks/gsm8k/gsm8k_handler.py b/skythought/tools/tasks/gsm8k/gsm8k_handler.py index 4a4780a..364785b 100644 --- a/skythought/tools/tasks/gsm8k/gsm8k_handler.py +++ b/skythought/tools/tasks/gsm8k/gsm8k_handler.py @@ -1,8 +1,8 @@ import re from typing import Any, Dict -from tasks.base import TaskHandler, TaskConfig -from util.math_parsing_util import extract_answer +from ...util.math_parsing_util import extract_answer +from ..base import TaskConfig, TaskHandler class GSM8KTaskHandler(TaskHandler): diff --git a/skythought/tools/tasks/livecodebench/livecodebench_handler.py b/skythought/tools/tasks/livecodebench/livecodebench_handler.py index 7fb13f1..807147b 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench_handler.py +++ b/skythought/tools/tasks/livecodebench/livecodebench_handler.py @@ -2,14 +2,15 @@ from typing import Dict from datasets import Dataset as HFDataset -from tasks.base import TaskHandler -from tasks.livecodebench.livecodebench_util import ( + +from ...util.common import has_code +from ..base import TaskHandler +from .livecodebench_util import ( map_to_example, post_process_code, translate_private_test_cases, unsafe_lcb_runTests, ) -from util.common import has_code class LiveCodeBenchTaskHandler(TaskHandler): diff --git a/skythought/tools/tasks/math/math_handler.py b/skythought/tools/tasks/math/math_handler.py index 1cb1c9c..adc7949 100644 --- a/skythought/tools/tasks/math/math_handler.py +++ b/skythought/tools/tasks/math/math_handler.py @@ -1,5 +1,5 @@ -from tasks.base import TaskHandler -from util.math_parsing_util import extract_answer, math_equal, strip_answer_string +from ...util.math_parsing_util import extract_answer, math_equal, strip_answer_string +from ..base import TaskHandler class MathTaskHandler(TaskHandler): @@ -54,4 +54,4 @@ def load_and_filter_dataset( self, start, end, split=None, source=None, filter_difficulty=None, args=None ): dataset = self.load_dataset(source=source, split=split).to_pandas() - return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] \ No newline at end of file + return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] diff --git a/skythought/tools/tasks/minervamath/minervamath_handler.py b/skythought/tools/tasks/minervamath/minervamath_handler.py index 033d6db..b82d014 100644 --- a/skythought/tools/tasks/minervamath/minervamath_handler.py +++ b/skythought/tools/tasks/minervamath/minervamath_handler.py @@ -1,6 +1,6 @@ -from tasks.base import TaskHandler -from util.math_parsing_util import extract_answer, math_equal, strip_answer_string -from tasks.math.math_handler import MathTaskHandler +from ...util.math_parsing_util import extract_answer, math_equal, strip_answer_string +from ..math.math_handler import MathTaskHandler + class MinervaMathTaskHandler(MathTaskHandler): @@ -10,4 +10,4 @@ def check_correctness(self, problem, generation): pred = extract_answer(generation) pred = strip_answer_string(pred) - return math_equal(pred, answer) \ No newline at end of file + return math_equal(pred, answer) diff --git a/skythought/tools/tasks/mmlu/mmlu_handler.py b/skythought/tools/tasks/mmlu/mmlu_handler.py index 67a39b7..f11a2b6 100644 --- a/skythought/tools/tasks/mmlu/mmlu_handler.py +++ b/skythought/tools/tasks/mmlu/mmlu_handler.py @@ -1,7 +1,8 @@ -from datasets import load_dataset - -from 
util.math_parsing_util import get_multiple_choice_answer, mmlu_pro_extract_answer -from tasks.base import TaskHandler, TaskConfig +from ...util.math_parsing_util import ( + get_multiple_choice_answer, + mmlu_pro_extract_answer, +) +from ..base import TaskConfig, TaskHandler class MMLUTaskHandler(TaskHandler): @@ -34,10 +35,11 @@ def update_results(self, problem, response): def get_multiple_choice_answers(self, problem): options = problem["choices"] - for i, (label, option) in enumerate(zip("ABCD", options)): - options[i] = f"({label}) {str(option).strip()}" - options = " ".join(options) - return f"Answer Choices: {options}" + options_str = "" + for _, (label, option) in enumerate(zip("ABCD", options)): + options_str += f"({label}) {str(option).strip()} " + options_str = options_str[:-1] # remove the last space + return f"Answer Choices: {options_str}" def make_conversations(self, data, system_prompt, model=None): conversations = [] diff --git a/skythought/tools/tasks/numina/numina_handler.py b/skythought/tools/tasks/numina/numina_handler.py index b00dcb4..dad0fc1 100644 --- a/skythought/tools/tasks/numina/numina_handler.py +++ b/skythought/tools/tasks/numina/numina_handler.py @@ -1,7 +1,8 @@ from datasets import load_dataset -from tasks.base import TaskHandler -from util.common import TimeoutException, timeout -from util.math_parsing_util import extract_answer, math_equal, strip_answer_string + +from ...util.common import TimeoutException, timeout +from ...util.math_parsing_util import extract_answer, math_equal, strip_answer_string +from ..base import TaskHandler class NUMINATaskHandler(TaskHandler): diff --git a/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py b/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py index cd54780..6807bbd 100644 --- a/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py +++ b/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py @@ -1,12 +1,11 @@ -from tasks.base import TaskHandler -from util.math_parsing_util import extract_answer, math_equal, strip_answer_string -from tasks.math.math_handler import MathTaskHandler +from ...util.math_parsing_util import extract_answer, math_equal, strip_answer_string +from ..math.math_handler import MathTaskHandler -class OlympiadBenchMathTaskHandler(MathTaskHandler): +class OlympiadBenchMathTaskHandler(MathTaskHandler): def check_correctness(self, problem, generation): - # all problems have final answer in a list + # all problems have final answer in a list answer = strip_answer_string(problem[self.task_config.answer_key][0]) pred = extract_answer(generation) pred = strip_answer_string(pred) - return math_equal(pred, answer) \ No newline at end of file + return math_equal(pred, answer) diff --git a/skythought/tools/tasks/taco/taco_handler.py b/skythought/tools/tasks/taco/taco_handler.py index 256ffed..4a615dc 100644 --- a/skythought/tools/tasks/taco/taco_handler.py +++ b/skythought/tools/tasks/taco/taco_handler.py @@ -3,10 +3,10 @@ from multiprocessing import Manager import numpy as np -from tasks.taco.taco_util import run_test as taco_run_test -from util.common import has_code +from ...util.common import has_code from ..base import TaskHandler +from .taco_util import run_test as taco_run_test class TACOTaskHandler(TaskHandler): diff --git a/skythought/tools/util/__init__.py b/skythought/tools/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/tools/preprocessing.py b/tests/tools/preprocessing.py new file mode 100644 index 0000000..24d64fc --- /dev/null 
+++ b/tests/tools/preprocessing.py
@@ -0,0 +1,50 @@
+import pytest
+from skyevals.tasks import MMLUTaskHandler, TaskConfig
+
+inputs = [
+    (
+        {
+            "is_stdin": False,
+            "question": "What is the capital of France?",
+            "choices": ["Paris", "London", "Berlin", "Madrid"],
+            "answer": "0",
+        },
+        TaskConfig(
+            handler="dummy",
+            dataset_path="dummy",
+            dataset_split="dummy",
+            question_key="question",
+            answer_key="answer",
+            templating_parameters={
+                "template": "Return your final response within \\boxed{{}}. {prompt}"
+            },
+        ),
+        MMLUTaskHandler,
+        [
+            {"role": "system", "content": "Please answer the following question:"},
+            {
+                "role": "user",
+                "content": "Return your final response within \\boxed{}. What is the capital of France?\nAnswer Choices: (A) Paris (B) London (C) Berlin (D) Madrid",  # noqa: E501
+            },
+        ],
+    ),
+]
+
+
+@pytest.mark.parametrize("row,config,handler_cls,expected_conversation", inputs)
+def test_make_conversations(row, config, handler_cls, expected_conversation):
+
+    # Expected system prompt
+    system_prompt = "Please answer the following question:"
+
+    # Initialize the handler
+    handler = handler_cls(config)
+
+    # Expected conversation format
+    # expected input
+    # Call make_conversations
+    conversations = handler.make_conversations([row], system_prompt)
+    # Assert the conversation is as expected
+    assert conversations == [
+        expected_conversation
+    ], f"Expected conversation {expected_conversation} but got {conversations}."

From 55b670cdbc84e7f82f1bbe4524d75c28939e13f2 Mon Sep 17 00:00:00 2001
From: SumanthRH
Date: Mon, 27 Jan 2025 21:45:14 +0000
Subject: [PATCH 22/47] rm llama factory change

Signed-off-by: SumanthRH
---
 skythought/train/LLaMA-Factory/setup.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/skythought/train/LLaMA-Factory/setup.py b/skythought/train/LLaMA-Factory/setup.py
index 3329f07..862e9b9 100644
--- a/skythought/train/LLaMA-Factory/setup.py
+++ b/skythought/train/LLaMA-Factory/setup.py
@@ -102,11 +102,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
-
-
-
-
-
-
-

From f2d88e927316ee4edce643437f35d44382f261dc Mon Sep 17 00:00:00 2001
From: SumanthRH
Date: Mon, 27 Jan 2025 21:47:58 +0000
Subject: [PATCH 23/47] merge issues

Signed-off-by: SumanthRH
---
 format.sh | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/format.sh b/format.sh
index 558555c..0c30a99 100644
--- a/format.sh
+++ b/format.sh
@@ -6,11 +6,6 @@ GIT_ROOT=$(git rev-parse --show-toplevel)
 TOOLS_RELATIVE=skythought/tools
 TOOLS_DIR=$GIT_ROOT/$TOOLS_RELATIVE
 
-if [ ! -d "$TOOLS_DIR" ]; then
-    echo "Error: Tools directory not found at $TOOLS_DIR"
-    exit 1
-fi
-
 if command -v uv >/dev/null 2>&1; then
     uv pip install -q pre-commit
 else
@@ -21,11 +16,5 @@ fi
 
 HOOK_SCRIPT=$GIT_ROOT/.githooks/pre-commit
 chmod +x $HOOK_SCRIPT
-# git config --local core.hooksPath ".githooks"
 # pre-commit run --all-files always runs from the root directory. we run this only on tools/ for now.
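The new tests/tools/preprocessing.py above drives MMLUTaskHandler end to end through make_conversations. A minimal sketch of the same call outside pytest, assuming the skyevals package declared in setup.py has been installed (for example with pip install -e .), mirrors the parametrized inputs of the test:

    from skyevals.tasks import MMLUTaskHandler, TaskConfig

    config = TaskConfig(
        handler="dummy",
        dataset_path="dummy",
        dataset_split="dummy",
        question_key="question",
        answer_key="answer",
        templating_parameters={
            "template": "Return your final response within \\boxed{{}}. {prompt}"
        },
    )
    row = {
        "question": "What is the capital of France?",
        "choices": ["Paris", "London", "Berlin", "Madrid"],
        "answer": "0",
    }
    handler = MMLUTaskHandler(config)
    # One conversation per row: a system message followed by the templated user prompt,
    # including the "Answer Choices: (A) ... (D) ..." suffix built by the handler.
    conversations = handler.make_conversations([row], "Please answer the following question:")
    print(conversations[0][1]["content"])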
-<<<<<<< HEAD:format.sh pre-commit run --files $GIT_ROOT/skythought/tools --config .pre-commit-config.yaml -======= -cd $TOOLS_DIR; -pre-commit run --files ./* --config .pre-commit-config.yaml ->>>>>>> 52a4ce7c6ff95b56188414d4228cb017069fa339:skythought/tools/format.sh From a9a93805698579d435f4b9667a2a751cdbf5659c Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Mon, 27 Jan 2025 23:09:13 +0000 Subject: [PATCH 24/47] more linting Signed-off-by: SumanthRH --- .githooks/pre-commit | 14 ----- .pre-commit-config.yaml | 3 +- format.sh | 6 +-- pyproject.toml | 3 +- setup.py | 13 +++-- .../.gitattributes | 0 .../{tools => skythought_evals}/README.md | 0 .../util => skythought_evals}/__init__.py | 0 .../base_instruct_evals.md | 0 .../combine_data.py | 2 +- .../convert_format.py | 3 +- .../convert_to_data.py | 2 +- .../{tools => skythought_evals}/eval.py | 0 .../inference_and_check.py | 3 +- .../label_math_difficulty.py | 3 +- .../labeled_numina_difficulty/README.md | 0 .../requirements.txt | 0 .../response_rewrite.py | 9 ++-- .../tasks/__init__.py | 6 +-- .../tasks/aime/aime.yaml | 0 .../tasks/aime/aime_handler.py | 3 +- .../tasks/amc23/amc23.yaml | 5 +- .../tasks/amc23/amc23_handler.py | 0 .../tasks/apps/apps.yaml | 3 +- .../tasks/apps/apps_handler.py | 2 +- .../tasks/apps/apps_util.py | 53 +++++++++++-------- .../tasks/arc/arc_c.yaml | 0 .../tasks/arc/arc_handler.py | 3 +- .../{tools => skythought_evals}/tasks/base.py | 0 .../tasks/gpqa_diamond/gpqa_diamond.yaml | 0 .../gpqa_diamond/gpqa_diamond_handler.py | 3 +- .../tasks/gsm8k/gsm8k.yaml | 0 .../tasks/gsm8k/gsm8k_handler.py | 3 +- .../tasks/livecodebench/livecodebench.yaml | 0 .../livecodebench/livecodebench_handler.py | 2 +- .../tasks/livecodebench/livecodebench_util.py | 22 +++++--- .../tasks/math/math500.yaml | 0 .../tasks/math/math_handler.py | 7 ++- .../tasks/minervamath/minervamath.yaml | 0 .../tasks/minervamath/minervamath_handler.py | 7 ++- .../tasks/mmlu/mmlu.yaml | 0 .../tasks/mmlu/mmlu_handler.py | 3 +- .../tasks/mmlu/mmlu_pro.yaml | 0 .../tasks/numina/numina.yaml | 4 +- .../tasks/numina/numina_handler.py | 8 ++- .../olympiadbench/olympiadbench_handler.py | 7 ++- .../olympiadbench/olympiadbench_math_en.yaml | 0 .../tasks/taco/pyext2.py | 18 ++++--- .../tasks/taco/taco.yaml | 4 +- .../tasks/taco/taco_handler.py | 2 +- .../tasks/taco/taco_util.py | 28 +++++----- .../tasks/task_util.py | 14 ++--- .../{tools => skythought_evals}/upload_hub.py | 0 skythought/skythought_evals/util/__init__.py | 0 .../util/common.py | 0 .../util/math_parsing_util.py | 28 +++++----- .../util/model_utils.py | 0 .../util/prompts.py | 8 +-- 58 files changed, 170 insertions(+), 134 deletions(-) delete mode 100755 .githooks/pre-commit rename skythought/{tools => skythought_evals}/.gitattributes (100%) rename skythought/{tools => skythought_evals}/README.md (100%) rename skythought/{tools/util => skythought_evals}/__init__.py (100%) rename skythought/{tools => skythought_evals}/base_instruct_evals.md (100%) rename skythought/{tools => skythought_evals}/combine_data.py (96%) rename skythought/{tools => skythought_evals}/convert_format.py (97%) rename skythought/{tools => skythought_evals}/convert_to_data.py (97%) rename skythought/{tools => skythought_evals}/eval.py (100%) rename skythought/{tools => skythought_evals}/inference_and_check.py (99%) rename skythought/{tools => skythought_evals}/label_math_difficulty.py (98%) rename skythought/{tools => skythought_evals}/labeled_numina_difficulty/README.md (100%) rename skythought/{tools => 
skythought_evals}/requirements.txt (100%) rename skythought/{tools => skythought_evals}/response_rewrite.py (99%) rename skythought/{tools => skythought_evals}/tasks/__init__.py (97%) rename skythought/{tools => skythought_evals}/tasks/aime/aime.yaml (100%) rename skythought/{tools => skythought_evals}/tasks/aime/aime_handler.py (95%) rename skythought/{tools => skythought_evals}/tasks/amc23/amc23.yaml (84%) rename skythought/{tools => skythought_evals}/tasks/amc23/amc23_handler.py (100%) rename skythought/{tools => skythought_evals}/tasks/apps/apps.yaml (94%) rename skythought/{tools => skythought_evals}/tasks/apps/apps_handler.py (98%) rename skythought/{tools => skythought_evals}/tasks/apps/apps_util.py (93%) rename skythought/{tools => skythought_evals}/tasks/arc/arc_c.yaml (100%) rename skythought/{tools => skythought_evals}/tasks/arc/arc_handler.py (98%) rename skythought/{tools => skythought_evals}/tasks/base.py (100%) rename skythought/{tools => skythought_evals}/tasks/gpqa_diamond/gpqa_diamond.yaml (100%) rename skythought/{tools => skythought_evals}/tasks/gpqa_diamond/gpqa_diamond_handler.py (97%) rename skythought/{tools => skythought_evals}/tasks/gsm8k/gsm8k.yaml (100%) rename skythought/{tools => skythought_evals}/tasks/gsm8k/gsm8k_handler.py (97%) rename skythought/{tools => skythought_evals}/tasks/livecodebench/livecodebench.yaml (100%) rename skythought/{tools => skythought_evals}/tasks/livecodebench/livecodebench_handler.py (98%) rename skythought/{tools => skythought_evals}/tasks/livecodebench/livecodebench_util.py (96%) rename skythought/{tools => skythought_evals}/tasks/math/math500.yaml (100%) rename skythought/{tools => skythought_evals}/tasks/math/math_handler.py (94%) rename skythought/{tools => skythought_evals}/tasks/minervamath/minervamath.yaml (100%) rename skythought/{tools => skythought_evals}/tasks/minervamath/minervamath_handler.py (76%) rename skythought/{tools => skythought_evals}/tasks/mmlu/mmlu.yaml (100%) rename skythought/{tools => skythought_evals}/tasks/mmlu/mmlu_handler.py (98%) rename skythought/{tools => skythought_evals}/tasks/mmlu/mmlu_pro.yaml (100%) rename skythought/{tools => skythought_evals}/tasks/numina/numina.yaml (85%) rename skythought/{tools => skythought_evals}/tasks/numina/numina_handler.py (94%) rename skythought/{tools => skythought_evals}/tasks/olympiadbench/olympiadbench_handler.py (77%) rename skythought/{tools => skythought_evals}/tasks/olympiadbench/olympiadbench_math_en.yaml (100%) rename skythought/{tools => skythought_evals}/tasks/taco/pyext2.py (96%) rename skythought/{tools => skythought_evals}/tasks/taco/taco.yaml (93%) rename skythought/{tools => skythought_evals}/tasks/taco/taco_handler.py (98%) rename skythought/{tools => skythought_evals}/tasks/taco/taco_util.py (98%) rename skythought/{tools => skythought_evals}/tasks/task_util.py (69%) rename skythought/{tools => skythought_evals}/upload_hub.py (100%) create mode 100644 skythought/skythought_evals/util/__init__.py rename skythought/{tools => skythought_evals}/util/common.py (100%) rename skythought/{tools => skythought_evals}/util/math_parsing_util.py (98%) rename skythought/{tools => skythought_evals}/util/model_utils.py (100%) rename skythought/{tools => skythought_evals}/util/prompts.py (98%) diff --git a/.githooks/pre-commit b/.githooks/pre-commit deleted file mode 100755 index cd2a157..0000000 --- a/.githooks/pre-commit +++ /dev/null @@ -1,14 +0,0 @@ -set -e - -GIT_ROOT=$(git rev-parse --show-toplevel) - -tools_relative=skythought/tools -# Only run pre-commit 
if changes are in tools/ -# Run pre-commit from tools/ directory to use linting rules in this directory -tools_files=$(git diff --cached --name-only --relative="./" -- $tools_relative) - -if [ -n $tools_files ]; then - cd $TOOLS_DIR; - # only get the diffs in tools/ and run pre-commit - pre-commit run --files $tools_files --config .pre-commit-config.yaml -fi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f978811..2f233aa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,8 @@ repos: hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] - exclude: ^skythought/train + exclude: (^skythought/train|skythought_evals/tasks/taco/pyext2\.py|skythought_evals/tasks/taco/taco_util\.py|skythought_evals/tasks/apps/apps_util\.py|skythought_evals/util/prompts\.py|skythought_evals/util/model_utils\.py)$ + # Black needs to be ran after ruff with --fix - repo: https://github.com/psf/black diff --git a/format.sh b/format.sh index 0c30a99..afdeb4a 100644 --- a/format.sh +++ b/format.sh @@ -12,9 +12,5 @@ else pip install -q pre-commit fi -# Hook file should be executable -HOOK_SCRIPT=$GIT_ROOT/.githooks/pre-commit -chmod +x $HOOK_SCRIPT - # pre-commit run --all-files always runs from the root directory. we run this only on tools/ for now. -pre-commit run --files $GIT_ROOT/skythought/tools --config .pre-commit-config.yaml +pre-commit run --all-files --config .pre-commit-config.yaml diff --git a/pyproject.toml b/pyproject.toml index 58eac71..123d375 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,4 +2,5 @@ line-length = 160 [tool.ruff.lint] -extend-select = ["E", "F", "I", "ASYNC", "B"] \ No newline at end of file +extend-select = ["E", "F", "I", "ASYNC", "B"] +ignore = ["F811", "B006"] \ No newline at end of file diff --git a/setup.py b/setup.py index 5d1c372..3a484f2 100644 --- a/setup.py +++ b/setup.py @@ -2,9 +2,14 @@ import setuptools setuptools.setup( - name="skyevals", + name="skythought_evals", version="0.0.1", - package_dir={"skyevals": "skythought/tools"}, # map skyevals to skythought/tools - packages=["skyevals"] - + [f"skyevals.{pkg}" for pkg in setuptools.find_packages(where="skythought/tools")], + package_dir={ + "skythought_evals": "skythought/skythought_evals" + }, # map skythought_evals to skythought/skythought_evals + packages=["skythought_evals"] + + [ + f"skythought_evals.{pkg}" + for pkg in setuptools.find_packages(where="skythought/skythought_evals") + ], ) diff --git a/skythought/tools/.gitattributes b/skythought/skythought_evals/.gitattributes similarity index 100% rename from skythought/tools/.gitattributes rename to skythought/skythought_evals/.gitattributes diff --git a/skythought/tools/README.md b/skythought/skythought_evals/README.md similarity index 100% rename from skythought/tools/README.md rename to skythought/skythought_evals/README.md diff --git a/skythought/tools/util/__init__.py b/skythought/skythought_evals/__init__.py similarity index 100% rename from skythought/tools/util/__init__.py rename to skythought/skythought_evals/__init__.py diff --git a/skythought/tools/base_instruct_evals.md b/skythought/skythought_evals/base_instruct_evals.md similarity index 100% rename from skythought/tools/base_instruct_evals.md rename to skythought/skythought_evals/base_instruct_evals.md diff --git a/skythought/tools/combine_data.py b/skythought/skythought_evals/combine_data.py similarity index 96% rename from skythought/tools/combine_data.py rename to skythought/skythought_evals/combine_data.py index 8106501..8178442 100644 --- 
a/skythought/tools/combine_data.py +++ b/skythought/skythought_evals/combine_data.py @@ -1,7 +1,7 @@ import json import random -from .util.prompts import system_prompt +from skythought_evals.util.prompts import system_prompt still2_jsonl_file = "../../data/public_long_form_thought_data_5k.jsonl" code_json_file = "../../data/converted_apps_long_form_thought_data_5k.json" diff --git a/skythought/tools/convert_format.py b/skythought/skythought_evals/convert_format.py similarity index 97% rename from skythought/tools/convert_format.py rename to skythought/skythought_evals/convert_format.py index 17fe9f3..d171989 100644 --- a/skythought/tools/convert_format.py +++ b/skythought/skythought_evals/convert_format.py @@ -6,10 +6,9 @@ from itertools import cycle import openai +from skythought_evals.util.prompts import convert_prompt, convert_prompt_example from tqdm import tqdm -from .util.prompts import convert_prompt, convert_prompt_example - global args diff --git a/skythought/tools/convert_to_data.py b/skythought/skythought_evals/convert_to_data.py similarity index 97% rename from skythought/tools/convert_to_data.py rename to skythought/skythought_evals/convert_to_data.py index 890c4f4..b1fbe2d 100644 --- a/skythought/tools/convert_to_data.py +++ b/skythought/skythought_evals/convert_to_data.py @@ -2,7 +2,7 @@ import json import os -from .util.prompts import system_prompt +from skythought_evals.util.prompts import system_prompt def main(): diff --git a/skythought/tools/eval.py b/skythought/skythought_evals/eval.py similarity index 100% rename from skythought/tools/eval.py rename to skythought/skythought_evals/eval.py diff --git a/skythought/tools/inference_and_check.py b/skythought/skythought_evals/inference_and_check.py similarity index 99% rename from skythought/tools/inference_and_check.py rename to skythought/skythought_evals/inference_and_check.py index 7a7d5f1..df60734 100644 --- a/skythought/tools/inference_and_check.py +++ b/skythought/skythought_evals/inference_and_check.py @@ -7,13 +7,12 @@ import numpy as np from openai import OpenAI +from skythought_evals.util.model_utils import MODEL_TO_NAME, SYSTEM_PROMPT from tasks import TASK_HANDLER_MAP, NUMINATaskHandler, TaskHandler from tasks.task_util import get_tasks from tqdm import tqdm from vllm import LLM, SamplingParams -from .util.model_utils import MODEL_TO_NAME, SYSTEM_PROMPT - module_dir = os.path.dirname(os.path.abspath(__file__)) TASK_NAMES_TO_YAML = get_tasks(os.path.join(module_dir, "tasks")) diff --git a/skythought/tools/label_math_difficulty.py b/skythought/skythought_evals/label_math_difficulty.py similarity index 98% rename from skythought/tools/label_math_difficulty.py rename to skythought/skythought_evals/label_math_difficulty.py index d810a83..1c18611 100644 --- a/skythought/tools/label_math_difficulty.py +++ b/skythought/skythought_evals/label_math_difficulty.py @@ -9,10 +9,9 @@ import openai from datasets import load_dataset +from skythought_evals.util.prompts import aops_criteria, grading_prompt from tqdm import tqdm -from .util.prompts import aops_criteria, grading_prompt - # Function to set the OpenAI API key def set_openai_key(api_key): diff --git a/skythought/tools/labeled_numina_difficulty/README.md b/skythought/skythought_evals/labeled_numina_difficulty/README.md similarity index 100% rename from skythought/tools/labeled_numina_difficulty/README.md rename to skythought/skythought_evals/labeled_numina_difficulty/README.md diff --git a/skythought/tools/requirements.txt b/skythought/skythought_evals/requirements.txt 
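A recurring change in this commit is that the entry-point scripts (combine_data.py, convert_format.py, inference_and_check.py, label_math_difficulty.py) drop the package-relative imports introduced earlier in favor of absolute skythought_evals.* imports. One plausible motivation, stated here as an assumption rather than something the commit message spells out, is that these files are also executed directly as scripts, where a relative form such as "from .util.prompts import system_prompt" raises "attempted relative import with no known parent package"; importing through the installed package name works in both the script and library cases:

    # Sketch only; assumes the package has been installed from the repo root, e.g. with pip install -e .
    from skythought_evals.util.model_utils import MODEL_TO_NAME, SYSTEM_PROMPT
    from skythought_evals.util.prompts import system_prompt

    print(system_prompt)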
similarity index 100% rename from skythought/tools/requirements.txt rename to skythought/skythought_evals/requirements.txt diff --git a/skythought/tools/response_rewrite.py b/skythought/skythought_evals/response_rewrite.py similarity index 99% rename from skythought/tools/response_rewrite.py rename to skythought/skythought_evals/response_rewrite.py index 04e888f..a08ead7 100644 --- a/skythought/tools/response_rewrite.py +++ b/skythought/skythought_evals/response_rewrite.py @@ -3,15 +3,14 @@ import os import random -from tqdm import tqdm -from vllm import LLM, SamplingParams - -from .util.math_parsing_util import strip_answer_string -from .util.model_utils import ( +from skythought_evals.util.math_parsing_util import strip_answer_string +from skythought_evals.util.model_utils import ( SUBPROBLEM_SPLIT_PROMPT, SUBSOLUTION_EXTRACTION_PROMPT, SYSTEM_PROMPT, ) +from tqdm import tqdm +from vllm import LLM, SamplingParams def load_dataset(dataset_path: str): diff --git a/skythought/tools/tasks/__init__.py b/skythought/skythought_evals/tasks/__init__.py similarity index 97% rename from skythought/tools/tasks/__init__.py rename to skythought/skythought_evals/tasks/__init__.py index 97cc533..a9d2882 100644 --- a/skythought/tools/tasks/__init__.py +++ b/skythought/skythought_evals/tasks/__init__.py @@ -2,16 +2,16 @@ from .amc23.amc23_handler import AMC23TaskHandler from .apps.apps_handler import APPSTaskHandler from .arc.arc_handler import ARCChallengeTaskHandler -from .base import TaskHandler, TaskConfig +from .base import TaskConfig, TaskHandler from .gpqa_diamond.gpqa_diamond_handler import GPQADiamondTaskHandler from .gsm8k.gsm8k_handler import GSM8KTaskHandler from .livecodebench.livecodebench_handler import LiveCodeBenchTaskHandler from .math.math_handler import MathTaskHandler +from .minervamath.minervamath_handler import MinervaMathTaskHandler from .mmlu.mmlu_handler import MMLUProTaskHandler, MMLUTaskHandler from .numina.numina_handler import NUMINATaskHandler -from .taco.taco_handler import TACOTaskHandler -from .minervamath.minervamath_handler import MinervaMathTaskHandler from .olympiadbench.olympiadbench_handler import OlympiadBenchMathTaskHandler +from .taco.taco_handler import TACOTaskHandler TASK_HANDLER_MAP = { "numina": NUMINATaskHandler, diff --git a/skythought/tools/tasks/aime/aime.yaml b/skythought/skythought_evals/tasks/aime/aime.yaml similarity index 100% rename from skythought/tools/tasks/aime/aime.yaml rename to skythought/skythought_evals/tasks/aime/aime.yaml diff --git a/skythought/tools/tasks/aime/aime_handler.py b/skythought/skythought_evals/tasks/aime/aime_handler.py similarity index 95% rename from skythought/tools/tasks/aime/aime_handler.py rename to skythought/skythought_evals/tasks/aime/aime_handler.py index afe7ea6..0fd8b43 100644 --- a/skythought/tools/tasks/aime/aime_handler.py +++ b/skythought/skythought_evals/tasks/aime/aime_handler.py @@ -1,6 +1,7 @@ from typing import Dict -from ...util.model_utils import MODEL_TO_NAME +from skythought_evals.util.model_utils import MODEL_TO_NAME + from ..math.math_handler import MathTaskHandler diff --git a/skythought/tools/tasks/amc23/amc23.yaml b/skythought/skythought_evals/tasks/amc23/amc23.yaml similarity index 84% rename from skythought/tools/tasks/amc23/amc23.yaml rename to skythought/skythought_evals/tasks/amc23/amc23.yaml index c86ed2a..a7ece4d 100644 --- a/skythought/tools/tasks/amc23/amc23.yaml +++ b/skythought/skythought_evals/tasks/amc23/amc23.yaml @@ -5,6 +5,7 @@ dataset_kwargs: dataset_split: train question_key: 
problem answer_key: answer -difficulty: null +preprocess_config: + difficulty: null templating_parameters: - template: "Return your final response within \\boxed{{}}. {problem}" \ No newline at end of file + template: "Return your final response within \\boxed{{}}. {problem}" diff --git a/skythought/tools/tasks/amc23/amc23_handler.py b/skythought/skythought_evals/tasks/amc23/amc23_handler.py similarity index 100% rename from skythought/tools/tasks/amc23/amc23_handler.py rename to skythought/skythought_evals/tasks/amc23/amc23_handler.py diff --git a/skythought/tools/tasks/apps/apps.yaml b/skythought/skythought_evals/tasks/apps/apps.yaml similarity index 94% rename from skythought/tools/tasks/apps/apps.yaml rename to skythought/skythought_evals/tasks/apps/apps.yaml index fa5f38c..455d027 100644 --- a/skythought/tools/tasks/apps/apps.yaml +++ b/skythought/skythought_evals/tasks/apps/apps.yaml @@ -5,7 +5,8 @@ dataset_kwargs: dataset_split: train question_key: question answer_key: null -difficulty: null +# preprocess_config: +# difficulty: null templating_parameters: with_fn_name_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}" without_fn_name_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. {prompt}" diff --git a/skythought/tools/tasks/apps/apps_handler.py b/skythought/skythought_evals/tasks/apps/apps_handler.py similarity index 98% rename from skythought/tools/tasks/apps/apps_handler.py rename to skythought/skythought_evals/tasks/apps/apps_handler.py index ad56e01..d9b5b20 100644 --- a/skythought/tools/tasks/apps/apps_handler.py +++ b/skythought/skythought_evals/tasks/apps/apps_handler.py @@ -4,8 +4,8 @@ from multiprocessing import Manager import numpy as np +from skythought_evals.util.common import has_code -from ...util.common import has_code from ..apps.apps_util import run_test as apps_run_test from ..base import TaskHandler diff --git a/skythought/tools/tasks/apps/apps_util.py b/skythought/skythought_evals/tasks/apps/apps_util.py similarity index 93% rename from skythought/tools/tasks/apps/apps_util.py rename to skythought/skythought_evals/tasks/apps/apps_util.py index c98dd1f..731f811 100644 --- a/skythought/tools/tasks/apps/apps_util.py +++ b/skythought/skythought_evals/tasks/apps/apps_util.py @@ -157,7 +157,7 @@ def run_test( print(f"start = {datetime.now().time()}") if problem_list is not None: - root = problem_list[prob_index] + root = problem_list[prob_index] # noqa: F841 in_outs = problem["input_output"] if debug: @@ -179,7 +179,7 @@ def run_test( reliability_guard() results = [] - sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" + sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, 
tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" # noqa: E501 if debug: print(f"loading test code = {datetime.now().time()}") @@ -249,7 +249,7 @@ def run_test( try: method = getattr(tmp, method_name) # get_attr second arg must be str - except: + except Exception: signal.alarm(0) e = sys.exc_info() print(f"unable to get function error = {e}") @@ -261,21 +261,21 @@ def run_test( try: if isinstance(inputs[0], dict): inputs = [{int(k): v for k, v in inputs[0].items()}] - except: + except Exception: True try: if isinstance(in_outs["outputs"][index], dict): in_outs["outputs"][index] = [ {int(k): v for k, v in in_outs["outputs"][index].items()} ] - except: + except Exception: True try: if isinstance(in_outs["outputs"][index][0], dict): in_outs["outputs"][index] = [ {int(k): v for k, v in in_outs["outputs"][index][0].items()} ] - except: + except Exception: True if debug: @@ -310,7 +310,7 @@ def run_test( [list(x) for x in output] == in_outs["outputs"][index][0] ) - except: + except Exception: True results.append(tmp_result) @@ -328,7 +328,8 @@ def run_test( signal.alarm(0) if debug: print( - f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + f"outputs = {output}, test outputs = {in_outs['outputs'][index]}," + f"inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" ) elif which_type == CODE_TYPE.standard_input: # Standard input faulthandler.enable() @@ -360,11 +361,14 @@ def run_test( nl = "\n" if not isinstance(inputs, list): print( - f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}," + f"inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}," + f" {output == [in_outs['outputs'][index]]}" ) else: print( - f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}," + f"inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" ) continue @@ -423,7 +427,7 @@ def run_test( print(f"Failed check2 exception = {e}") pass - if tmp_result == True: + if tmp_result: results.append(tmp_result) continue @@ -435,14 +439,17 @@ def run_test( nl = "\n" if not isinstance(inputs, list): print( - f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + f"output = {output}, test outputs = {in_outs['outputs'][index]}," + f"inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}," + f" {output == [in_outs['outputs'][index]]}" ) else: print( - f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + f"output = {output}, test outputs = {in_outs['outputs'][index]}," + f"inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" ) - if tmp_result == True: + if tmp_result: results.append(tmp_result) continue @@ -474,7 +481,7 @@ def run_test( except Exception: pass - if tmp_result == True: + if tmp_result: results.append(tmp_result) continue @@ -526,7 +533,7 @@ def run_test( except Exception 
as e: print(f"Failed check6 exception = {e}") - if tmp_result == True and debug: + if tmp_result and debug: print("PASSED") results.append(tmp_result) @@ -535,11 +542,14 @@ def run_test( nl = "\n" if not isinstance(inputs, list): print( - f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + f"output = {output}, test outputs = {in_outs['outputs'][index]}," + f"inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}," + f" {output == [in_outs['outputs'][index]]}" ) else: print( - f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" + f"output = {output}, test outputs = {in_outs['outputs'][index]}," + f"inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}" ) return results @@ -697,9 +707,10 @@ def main(args): ]: tmp = get_solutions(problem_list, prob_index) print("sol", tmp) - elif args.data == "starter": - tmp = get_starter(problem_list, prob_index) - print("starter", tmp) + # TODO: fix this by adding get_starter + # elif args.data == "starter": + # tmp = get_starter(problem_list, prob_index) + # print("starter", tmp) elif args.data in ["test", "t"]: # test it with sols sols = get_solutions(problem_list, prob_index) diff --git a/skythought/tools/tasks/arc/arc_c.yaml b/skythought/skythought_evals/tasks/arc/arc_c.yaml similarity index 100% rename from skythought/tools/tasks/arc/arc_c.yaml rename to skythought/skythought_evals/tasks/arc/arc_c.yaml diff --git a/skythought/tools/tasks/arc/arc_handler.py b/skythought/skythought_evals/tasks/arc/arc_handler.py similarity index 98% rename from skythought/tools/tasks/arc/arc_handler.py rename to skythought/skythought_evals/tasks/arc/arc_handler.py index 34c88f4..221c94d 100644 --- a/skythought/tools/tasks/arc/arc_handler.py +++ b/skythought/skythought_evals/tasks/arc/arc_handler.py @@ -1,7 +1,8 @@ import re from typing import Any, Dict -from ...util.math_parsing_util import extract_answer +from skythought_evals.util.math_parsing_util import extract_answer + from ..base import TaskConfig, TaskHandler diff --git a/skythought/tools/tasks/base.py b/skythought/skythought_evals/tasks/base.py similarity index 100% rename from skythought/tools/tasks/base.py rename to skythought/skythought_evals/tasks/base.py diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml b/skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond.yaml similarity index 100% rename from skythought/tools/tasks/gpqa_diamond/gpqa_diamond.yaml rename to skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond.yaml diff --git a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py b/skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond_handler.py similarity index 97% rename from skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py rename to skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond_handler.py index bff304b..fb0d5ef 100644 --- a/skythought/tools/tasks/gpqa_diamond/gpqa_diamond_handler.py +++ b/skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond_handler.py @@ -1,6 +1,7 @@ import random -from ...util.math_parsing_util import get_multiple_choice_answer +from skythought_evals.util.math_parsing_util import get_multiple_choice_answer + from ..base import TaskHandler diff --git a/skythought/tools/tasks/gsm8k/gsm8k.yaml b/skythought/skythought_evals/tasks/gsm8k/gsm8k.yaml similarity index 100% rename from 
skythought/tools/tasks/gsm8k/gsm8k.yaml rename to skythought/skythought_evals/tasks/gsm8k/gsm8k.yaml diff --git a/skythought/tools/tasks/gsm8k/gsm8k_handler.py b/skythought/skythought_evals/tasks/gsm8k/gsm8k_handler.py similarity index 97% rename from skythought/tools/tasks/gsm8k/gsm8k_handler.py rename to skythought/skythought_evals/tasks/gsm8k/gsm8k_handler.py index 364785b..f913b51 100644 --- a/skythought/tools/tasks/gsm8k/gsm8k_handler.py +++ b/skythought/skythought_evals/tasks/gsm8k/gsm8k_handler.py @@ -1,7 +1,8 @@ import re from typing import Any, Dict -from ...util.math_parsing_util import extract_answer +from skythought_evals.util.math_parsing_util import extract_answer + from ..base import TaskConfig, TaskHandler diff --git a/skythought/tools/tasks/livecodebench/livecodebench.yaml b/skythought/skythought_evals/tasks/livecodebench/livecodebench.yaml similarity index 100% rename from skythought/tools/tasks/livecodebench/livecodebench.yaml rename to skythought/skythought_evals/tasks/livecodebench/livecodebench.yaml diff --git a/skythought/tools/tasks/livecodebench/livecodebench_handler.py b/skythought/skythought_evals/tasks/livecodebench/livecodebench_handler.py similarity index 98% rename from skythought/tools/tasks/livecodebench/livecodebench_handler.py rename to skythought/skythought_evals/tasks/livecodebench/livecodebench_handler.py index 807147b..5e50ff3 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench_handler.py +++ b/skythought/skythought_evals/tasks/livecodebench/livecodebench_handler.py @@ -2,8 +2,8 @@ from typing import Dict from datasets import Dataset as HFDataset +from skythought_evals.util.common import has_code -from ...util.common import has_code from ..base import TaskHandler from .livecodebench_util import ( map_to_example, diff --git a/skythought/tools/tasks/livecodebench/livecodebench_util.py b/skythought/skythought_evals/tasks/livecodebench/livecodebench_util.py similarity index 96% rename from skythought/tools/tasks/livecodebench/livecodebench_util.py rename to skythought/skythought_evals/tasks/livecodebench/livecodebench_util.py index bc00fb6..947e3fe 100644 --- a/skythought/tools/tasks/livecodebench/livecodebench_util.py +++ b/skythought/skythought_evals/tasks/livecodebench/livecodebench_util.py @@ -8,6 +8,7 @@ import multiprocessing import os import pickle +import re import shutil import signal import subprocess @@ -56,7 +57,12 @@ def post_process_tests_inputs(raw_text, is_stdin): return formatted_tests else: # Step 1: Clean the input string by removing surrounding markdown syntax and extra spaces - cleaned_string = raw_text.strip().strip("```json").strip("```").strip() + # TODO: .strip() with multiple characters is misleading, this will look for the individual characters in any order. 
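    # Sketch of the pitfall the TODO above is pointing at (not part of the diff): str.strip()
    # treats a multi-character argument as a set of characters to trim, not as a literal
    # prefix/suffix, so "```json" can also eat payload characters that happen to be backticks,
    # "j", "s", "o" or "n":
    #     >>> "json stats```".strip("```json").strip()
    #     'stat'
    # One possible replacement, assuming Python 3.9+ for removeprefix/removesuffix, peels the
    # exact fence markers instead:
    #     cleaned = raw_text.strip()
    #     cleaned = cleaned.removeprefix("```json").removeprefix("```").removesuffix("```").strip()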
+ # we should switch to .replace and test + # TODO: Do not ignore B005 + cleaned_string = ( + raw_text.strip().strip("```json").strip("```").strip() # noqa: B005 + ) # Step 2: Check if it's a JSON array if cleaned_string.startswith("[") and cleaned_string.endswith("]"): @@ -77,16 +83,16 @@ def post_process_tests_inputs(raw_text, is_stdin): try: test_cases = json.loads(json_array_string) for test_case in test_cases: - test_case[ - "testtype" - ] = "functional" # Add 'testtype' for each test case + test_case["testtype"] = ( + "functional" # Add 'testtype' for each test case + ) return test_cases except json.JSONDecodeError as e: print(f"Error parsing concatenated JSON: {e}") # If no matches are found, fall back to line-by-line parsing cleaned_lines = cleaned_string.split("\n") - if test_cases == None: + if test_cases is None: test_cases = [] for line in cleaned_lines: try: @@ -296,7 +302,7 @@ def unsafe_lcb_runTests(problem, completion, timeout, runtime_debug, is_extracte p.kill() # if len(result) < len(test_cases): ## This is supposed to be the case where not all test passed in the given timeout - for i in range(len(test_cases) - len(result)): + for _i in range(len(test_cases) - len(result)): result.append((False, "Time out!.", "Error: Time out!", float("inf"))) return result @@ -307,7 +313,7 @@ def run_tests_for_one_example( time_elapsed = float("inf") test_type = test_cases[0]["testtype"] reliability_guard() - for i, test_case in enumerate(test_cases): + for _i, test_case in enumerate(test_cases): output_error = "" output_value = "" try: @@ -339,7 +345,7 @@ def run_tests_for_one_example( output_error = f"For test input: {test_input}. Expected output is: {test_output}, but got error: {e}." output_value = f"Error: {e}." if output_error == "": - output_error = f"For test input: {test_input}. Expected output is: {test_output}, your solution correctly passes this test with output {output_value}." + output_error = f"For test input: {test_input}. Expected output is: {test_output}, your solution correctly passes this test with output {output_value}." 
# noqa: E501 result_list.append((passed, output_error, output_value, time_elapsed)) if not passed: diff --git a/skythought/tools/tasks/math/math500.yaml b/skythought/skythought_evals/tasks/math/math500.yaml similarity index 100% rename from skythought/tools/tasks/math/math500.yaml rename to skythought/skythought_evals/tasks/math/math500.yaml diff --git a/skythought/tools/tasks/math/math_handler.py b/skythought/skythought_evals/tasks/math/math_handler.py similarity index 94% rename from skythought/tools/tasks/math/math_handler.py rename to skythought/skythought_evals/tasks/math/math_handler.py index adc7949..ca948b2 100644 --- a/skythought/tools/tasks/math/math_handler.py +++ b/skythought/skythought_evals/tasks/math/math_handler.py @@ -1,4 +1,9 @@ -from ...util.math_parsing_util import extract_answer, math_equal, strip_answer_string +from skythought_evals.util.math_parsing_util import ( + extract_answer, + math_equal, + strip_answer_string, +) + from ..base import TaskHandler diff --git a/skythought/tools/tasks/minervamath/minervamath.yaml b/skythought/skythought_evals/tasks/minervamath/minervamath.yaml similarity index 100% rename from skythought/tools/tasks/minervamath/minervamath.yaml rename to skythought/skythought_evals/tasks/minervamath/minervamath.yaml diff --git a/skythought/tools/tasks/minervamath/minervamath_handler.py b/skythought/skythought_evals/tasks/minervamath/minervamath_handler.py similarity index 76% rename from skythought/tools/tasks/minervamath/minervamath_handler.py rename to skythought/skythought_evals/tasks/minervamath/minervamath_handler.py index b82d014..ef13461 100644 --- a/skythought/tools/tasks/minervamath/minervamath_handler.py +++ b/skythought/skythought_evals/tasks/minervamath/minervamath_handler.py @@ -1,4 +1,9 @@ -from ...util.math_parsing_util import extract_answer, math_equal, strip_answer_string +from skythought_evals.util.math_parsing_util import ( + extract_answer, + math_equal, + strip_answer_string, +) + from ..math.math_handler import MathTaskHandler diff --git a/skythought/tools/tasks/mmlu/mmlu.yaml b/skythought/skythought_evals/tasks/mmlu/mmlu.yaml similarity index 100% rename from skythought/tools/tasks/mmlu/mmlu.yaml rename to skythought/skythought_evals/tasks/mmlu/mmlu.yaml diff --git a/skythought/tools/tasks/mmlu/mmlu_handler.py b/skythought/skythought_evals/tasks/mmlu/mmlu_handler.py similarity index 98% rename from skythought/tools/tasks/mmlu/mmlu_handler.py rename to skythought/skythought_evals/tasks/mmlu/mmlu_handler.py index f11a2b6..b69174c 100644 --- a/skythought/tools/tasks/mmlu/mmlu_handler.py +++ b/skythought/skythought_evals/tasks/mmlu/mmlu_handler.py @@ -1,7 +1,8 @@ -from ...util.math_parsing_util import ( +from skythought_evals.util.math_parsing_util import ( get_multiple_choice_answer, mmlu_pro_extract_answer, ) + from ..base import TaskConfig, TaskHandler diff --git a/skythought/tools/tasks/mmlu/mmlu_pro.yaml b/skythought/skythought_evals/tasks/mmlu/mmlu_pro.yaml similarity index 100% rename from skythought/tools/tasks/mmlu/mmlu_pro.yaml rename to skythought/skythought_evals/tasks/mmlu/mmlu_pro.yaml diff --git a/skythought/tools/tasks/numina/numina.yaml b/skythought/skythought_evals/tasks/numina/numina.yaml similarity index 85% rename from skythought/tools/tasks/numina/numina.yaml rename to skythought/skythought_evals/tasks/numina/numina.yaml index 1a0f72d..b8cbdff 100644 --- a/skythought/tools/tasks/numina/numina.yaml +++ b/skythought/skythought_evals/tasks/numina/numina.yaml @@ -6,5 +6,5 @@ question_key: problem answer_key: 
solution templating_parameters: template: "Return your final response within \\boxed{{}}. {prompt}" -preprocess_config: - difficulty: null \ No newline at end of file +# preprocess_config: +# difficulty: null \ No newline at end of file diff --git a/skythought/tools/tasks/numina/numina_handler.py b/skythought/skythought_evals/tasks/numina/numina_handler.py similarity index 94% rename from skythought/tools/tasks/numina/numina_handler.py rename to skythought/skythought_evals/tasks/numina/numina_handler.py index dad0fc1..f2842e0 100644 --- a/skythought/tools/tasks/numina/numina_handler.py +++ b/skythought/skythought_evals/tasks/numina/numina_handler.py @@ -1,7 +1,11 @@ from datasets import load_dataset +from skythought_evals.util.common import TimeoutException, timeout +from skythought_evals.util.math_parsing_util import ( + extract_answer, + math_equal, + strip_answer_string, +) -from ...util.common import TimeoutException, timeout -from ...util.math_parsing_util import extract_answer, math_equal, strip_answer_string from ..base import TaskHandler diff --git a/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py b/skythought/skythought_evals/tasks/olympiadbench/olympiadbench_handler.py similarity index 77% rename from skythought/tools/tasks/olympiadbench/olympiadbench_handler.py rename to skythought/skythought_evals/tasks/olympiadbench/olympiadbench_handler.py index 6807bbd..4264f56 100644 --- a/skythought/tools/tasks/olympiadbench/olympiadbench_handler.py +++ b/skythought/skythought_evals/tasks/olympiadbench/olympiadbench_handler.py @@ -1,4 +1,9 @@ -from ...util.math_parsing_util import extract_answer, math_equal, strip_answer_string +from skythought_evals.util.math_parsing_util import ( + extract_answer, + math_equal, + strip_answer_string, +) + from ..math.math_handler import MathTaskHandler diff --git a/skythought/tools/tasks/olympiadbench/olympiadbench_math_en.yaml b/skythought/skythought_evals/tasks/olympiadbench/olympiadbench_math_en.yaml similarity index 100% rename from skythought/tools/tasks/olympiadbench/olympiadbench_math_en.yaml rename to skythought/skythought_evals/tasks/olympiadbench/olympiadbench_math_en.yaml diff --git a/skythought/tools/tasks/taco/pyext2.py b/skythought/skythought_evals/tasks/taco/pyext2.py similarity index 96% rename from skythought/tools/tasks/taco/pyext2.py rename to skythought/skythought_evals/tasks/taco/pyext2.py index 1bff5e4..514cf2b 100644 --- a/skythought/tools/tasks/taco/pyext2.py +++ b/skythought/skythought_evals/tasks/taco/pyext2.py @@ -42,9 +42,9 @@ "run_main", ] -import inspect -import sys -import types +import inspect # noqa: E402 +import sys # noqa: E402 +import types # noqa: E402 def __targspec(func, specs, attr="__orig_arg__"): @@ -379,7 +379,9 @@ def from_string(module_name_for_code_eval, docstring, s): class CaseObject(object): - "The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. It can be called with multiple parameters, in which case it checks if its value equals any of the arguments." + """The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. + It can be called with multiple parameters, in which case it checks if its value equals any of the arguments. 
+ """ def __init__(self, value): self.value = value @@ -391,7 +393,7 @@ def __call__(self, *args): "res", not self.did_pass and any([self.value == rhs for rhs in args]) ): self.did_match = True - return res + return res # noqa: F821 def quit(self): "Forces all other calls to return False. Equilavent of a ``break`` statement." @@ -432,7 +434,8 @@ def switch(value): def tail_recurse(spec=None): """Remove tail recursion from a function. - :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. If ``None,`` tail recursion is always called unless the function returns a value. + :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. + If ``None``, tail recursion is always called unless the function returns a value. .. note:: @@ -527,7 +530,8 @@ def _wrap(f): def safe_unpack(seq, ln, fill=None): - """Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. Empty values will be filled in with `fill`, while any extra values will be cut off. + """Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. + Empty values will be filled in with `fill`, while any extra values will be cut off. :param seq: The sequence to unpack. diff --git a/skythought/tools/tasks/taco/taco.yaml b/skythought/skythought_evals/tasks/taco/taco.yaml similarity index 93% rename from skythought/tools/tasks/taco/taco.yaml rename to skythought/skythought_evals/tasks/taco/taco.yaml index c0521cb..85d317b 100644 --- a/skythought/tools/tasks/taco/taco.yaml +++ b/skythought/skythought_evals/tasks/taco/taco.yaml @@ -14,6 +14,6 @@ templating_parameters: stdin_template: "{input}\nUse Standard Input format\nANSWER:\n" # call template is used when there is starter code or fn_name call_template: "{input}\nUse Call-Based format\nANSWER:\n" -preprocess_config: - difficulty: null +# preprocess_config: +# difficulty: null diff --git a/skythought/tools/tasks/taco/taco_handler.py b/skythought/skythought_evals/tasks/taco/taco_handler.py similarity index 98% rename from skythought/tools/tasks/taco/taco_handler.py rename to skythought/skythought_evals/tasks/taco/taco_handler.py index 4a615dc..48b36a0 100644 --- a/skythought/tools/tasks/taco/taco_handler.py +++ b/skythought/skythought_evals/tasks/taco/taco_handler.py @@ -3,8 +3,8 @@ from multiprocessing import Manager import numpy as np +from skythought_evals.util.common import has_code -from ...util.common import has_code from ..base import TaskHandler from .taco_util import run_test as taco_run_test diff --git a/skythought/tools/tasks/taco/taco_util.py b/skythought/skythought_evals/tasks/taco/taco_util.py similarity index 98% rename from skythought/tools/tasks/taco/taco_util.py rename to skythought/skythought_evals/tasks/taco/taco_util.py index 68148b5..1941229 100644 --- a/skythought/tools/tasks/taco/taco_util.py +++ b/skythought/skythought_evals/tasks/taco/taco_util.py @@ -172,20 +172,20 @@ def process_input_output(inputs, outputs): try: if isinstance(inputs[0], dict): inputs = [{int(k): v for k, v in inputs[0].items()}] - except: - True + except Exception: + pass try: if isinstance(outputs, dict): outputs = [{int(k): v for k, v in outputs.items()}] - except: - True + except Exception: + pass try: if isinstance(outputs[0], dict): outputs = [{int(k): v for k, v in outputs[0].items()}] - except: - True + except Exception: + pass return inputs, outputs @@ -214,7 +214,7 @@ def compile_and_get_func(program, 
which_type, method_name, timeout, debug): signal.alarm(timeout) method = getattr(tmp, method_name) # get_attr second arg must be str signal.alarm(0) - except: + except Exception: signal.alarm(0) e = sys.exc_info() if debug: @@ -224,7 +224,7 @@ def compile_and_get_func(program, which_type, method_name, timeout, debug): def synthesize_cb_code(raw_code, debug=False): - sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" + sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" # noqa: E501 if debug: print(f"loading test code = {datetime.now().time()}") sol += raw_code @@ -232,7 +232,7 @@ def synthesize_cb_code(raw_code, debug=False): def synthesize_std_code(raw_code, debug=False): - normal_import_lines = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" + normal_import_lines = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" # noqa: E501 if debug: print(f"loading test code = {datetime.now().time()}") @@ -334,7 +334,7 @@ def execute_cb_code( if debug: print(f"Standard input runtime error = {e}") if early_stop: - for i in range(index, len(inputs_list)): + for _i in range(index, len(inputs_list)): results.append((False, EXECUTION_RESULTS[-2])) break else: @@ -455,9 +455,11 @@ def execute_std_code( assert exec_code != -3 exec_results[i] = ( exec_code == 1, - EXECUTION_RESULTS[exec_code] - if exec_code > -3 - else EXECUTION_RESULTS[exec_code].format(result.returncode), + ( + EXECUTION_RESULTS[exec_code] + if exec_code > -3 + else EXECUTION_RESULTS[exec_code].format(result.returncode) + ), ) if exec_code >= 0: if debug: diff --git a/skythought/tools/tasks/task_util.py b/skythought/skythought_evals/tasks/task_util.py similarity index 69% rename from skythought/tools/tasks/task_util.py rename to skythought/skythought_evals/tasks/task_util.py index 5a736c8..070a38e 100644 --- a/skythought/tools/tasks/task_util.py +++ b/skythought/skythought_evals/tasks/task_util.py @@ -1,16 +1,18 @@ -import glob 
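The practical effect of the repeated bare-except cleanups above (except: becoming except Exception: in taco_util.py, and similarly in apps_util.py and math_parsing_util.py) is that BaseException subclasses such as KeyboardInterrupt and SystemExit are no longer swallowed by the test harness, while ordinary per-test failures are still contained. A small self-contained illustration, not code from this repo:

    def run_one(x):
        # Stand-in for executing a candidate solution against one test case.
        if x == 0:
            raise ValueError("bad test case")
        return 1 / x

    results = []
    for x in [2, 0, 4]:
        try:
            results.append(run_one(x))
        except Exception:  # contains ValueError and friends; Ctrl-C still propagates
            results.append(None)
    print(results)  # [0.5, None, 0.25]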
-import os +import glob +import os from typing import Dict + def get_tasks(task_root_dir: str) -> Dict[str, str]: - """Returns a dictionary of task names and their corresponding yaml file paths""" + """Returns a dictionary of task names and their corresponding yaml file paths""" # list all yamls in subdirectories name_to_yaml = {} - for yaml_file in glob.glob(os.path.join(task_root_dir, "**", "*.yaml"), recursive=True): + for yaml_file in glob.glob( + os.path.join(task_root_dir, "**", "*.yaml"), recursive=True + ): # arc.yaml -> arc name = os.path.basename(yaml_file).split(".")[0] name_to_yaml[name] = yaml_file - - return name_to_yaml + return name_to_yaml diff --git a/skythought/tools/upload_hub.py b/skythought/skythought_evals/upload_hub.py similarity index 100% rename from skythought/tools/upload_hub.py rename to skythought/skythought_evals/upload_hub.py diff --git a/skythought/skythought_evals/util/__init__.py b/skythought/skythought_evals/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/skythought/tools/util/common.py b/skythought/skythought_evals/util/common.py similarity index 100% rename from skythought/tools/util/common.py rename to skythought/skythought_evals/util/common.py diff --git a/skythought/tools/util/math_parsing_util.py b/skythought/skythought_evals/util/math_parsing_util.py similarity index 98% rename from skythought/tools/util/math_parsing_util.py rename to skythought/skythought_evals/util/math_parsing_util.py index 798930f..6872d9f 100644 --- a/skythought/tools/util/math_parsing_util.py +++ b/skythought/skythought_evals/util/math_parsing_util.py @@ -16,7 +16,7 @@ def convert_word_number(text: str) -> str: try: text = str(w2n.word_to_num(text)) - except: + except Exception: pass return text @@ -33,7 +33,7 @@ def _fix_fracs(string): else: try: assert len(substr) >= 2 - except: + except Exception: return string a = substr[0] b = substr[1] @@ -66,7 +66,7 @@ def _fix_a_slash_b(string): assert string == "{}/{}".format(a, b) new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" return new_string - except: + except Exception: return string @@ -227,7 +227,7 @@ def replace_match(match): # Split the string into a list of integers try: integer_list = list(map(int, string.split(","))) - except: + except Exception: integer_list = list(map(int, "-1,-1".split(","))) # Sort the list in ascending order @@ -351,14 +351,14 @@ def parse_digits(num): num = regex.sub(",", "", str(num)) try: return float(num) - except: + except Exception: if num.endswith("%"): num = num[:-1] if num.endswith("\\"): num = num[:-1] try: return float(num) / 100 - except: + except Exception: pass return None @@ -423,7 +423,7 @@ def math_equal( except Exception: continue return False - except: + except Exception: pass if not prediction and prediction not in [0, False]: @@ -573,10 +573,10 @@ def _parse(s): for f in [parse_latex, parse_expr, latex2sympy]: try: return f(s.replace("\\\\", "\\")) - except: + except Exception: try: return f(s) - except: + except Exception: pass return s @@ -587,27 +587,27 @@ def _parse(s): try: if str(a) == str(b) or a == b: return True - except: + except Exception: pass # simplify equal try: if a.equals(b) or simplify(a - b) == 0: return True - except: + except Exception: pass # equation equal try: if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)): return True - except: + except Exception: pass try: if numeric_equal(float(N(a)), float(N(b))): return True - except: + except Exception: pass # matrix @@ -618,7 +618,7 @@ def _parse(s): _b = b.applyfunc(lambda x: 
round(x, 3)) if _a.equals(_b): return True - except: + except Exception: pass return False diff --git a/skythought/tools/util/model_utils.py b/skythought/skythought_evals/util/model_utils.py similarity index 100% rename from skythought/tools/util/model_utils.py rename to skythought/skythought_evals/util/model_utils.py diff --git a/skythought/tools/util/prompts.py b/skythought/skythought_evals/util/prompts.py similarity index 98% rename from skythought/tools/util/prompts.py rename to skythought/skythought_evals/util/prompts.py index c815db4..a59f7c1 100644 --- a/skythought/tools/util/prompts.py +++ b/skythought/skythought_evals/util/prompts.py @@ -40,7 +40,7 @@ 8: For each positive integer $n$, the Bank of Cape Town issues coins of denomination $\frac1n$. Given a finite collection of such coins (of not necessarily different denominations) with total value at most most $99+\frac{1}{2}$, prove that it is possible to split this collection into $100$ or fewer groups, such that each group has total value at most $1$. (IMO 2014/5) \ 9: Let $k$ be a positive integer and let $S$ be a finite set of odd prime numbers. Prove that there is at most one way (up to rotation and reflection) to place the elements of $S$ around the circle such that the product of any two neighbors is of the form $x^2+x+k$ for some positive integer $x$. (IMO 2022/3) \ 10: Prove that there exists a positive constant $c$ such that the following statement is true: Consider an integer $n > 1$, and a set $\mathcal S$ of $n$ points in the plane such that the distance between any two different points in $\mathcal S$ is at least 1. It follows that there is a line $\ell$ separating $\mathcal S$ such that the distance from any point of $\mathcal S$ to $\ell$ is at least $cn^{-1/3}$. \ - (A line $\ell$ separates a set of points S if some segment joining two points in $\mathcal S$ crosses $\ell$.) (IMO 2020/6)" + (A line $\ell$ separates a set of points S if some segment joining two points in $\mathcal S$ crosses $\ell$.) (IMO 2020/6)" # noqa: E501 convert_prompt = "Another solution is written in an unstructured way. Your job is to convert them into two sections: \ @@ -54,7 +54,7 @@ {example} \ Important: You should almost copy all the contents word-by-word of the original solution. Just convert them into two sections. \ Make sure you include: <|begin_of_slow_thought|>, <|end_of_slow_thought|>, <|begin_of_solution|>,<|end_of_solution|> These four headers explicitly. \ - Content to be converted: {content}" + Content to be converted: {content}" # noqa: E501 convert_prompt_example = ( "<|begin_of_thought|>\n\n" @@ -92,7 +92,7 @@ "I think that's the answer.<|end_of_thought|>\n\n" "<|begin_of_solution|>\n\n" "Mr. Wang leaves home at 6 AM and rides at a speed of 12 km/h, stopping to rest for 6 minutes after every 30 minutes of riding. " - "He arrives at a park 16.8 km away. To determine the angle between the hour and minute hands on his watch when he arrives, we first calculate the total time taken.\n\n" + "He arrives at a park 16.8 km away. To determine the angle between the hour and minute hands on his watch when he arrives, we first calculate the total time taken.\n\n" # noqa: E501 "1. **Riding time without stops**:\n\n" "$$\\text{Time} = \\frac{\\text{Distance}}{\\text{Speed}} = \\frac{16.8 \\text{ km}}{12 \\text{ km/h}} = 1.4 \\text{ hours} = 84 \\text{ minutes}$$\n\n" "2. 
**Rest periods**:\n\n" @@ -107,7 +107,7 @@ "$$\\text{Angle} = |30H - 5.5M|$$\n\n" " - At 7:36, $H = 7$ and $M = 36$:\n\n" "$$\\text{Angle} = |30 \\times 7 - 5.5 \\times 36| = |210 - 198| = 12 \\text{ degrees}$$\n\n" - "Thus, the angle between the hour and minute hands on his watch is $\\boxed{12}$.<|end_of_solution|>\n" + "Thus, the angle between the hour and minute hands on his watch is $\\boxed{12}$.<|end_of_solution|>\n" # noqa: E501 ) # From https://arxiv.org/pdf/2412.09413 From d3dde5de518ca1c6cc5f20245cae92dbc0007464 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Mon, 27 Jan 2025 23:16:23 +0000 Subject: [PATCH 25/47] x Signed-off-by: SumanthRH --- setup.py | 11 +++++++++++ .../skythought_evals/tasks/taco/pyext2.py | 19 +++++++------------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index 3a484f2..cdfebb7 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,15 @@ # setup module skyevals in tools directory +from pathlib import Path + import setuptools + +def get_requirements(): + req_path = Path("skythought/skythought_evals/requirements.txt") + with open(req_path, "r") as f: + return f.read().splitlines() + + setuptools.setup( name="skythought_evals", version="0.0.1", @@ -12,4 +21,6 @@ f"skythought_evals.{pkg}" for pkg in setuptools.find_packages(where="skythought/skythought_evals") ], + install_requires=get_requirements(), + python_requires=">=3.9,<3.12", # pyext doesn't work with python 3.12 ) diff --git a/skythought/skythought_evals/tasks/taco/pyext2.py b/skythought/skythought_evals/tasks/taco/pyext2.py index 514cf2b..42ea666 100644 --- a/skythought/skythought_evals/tasks/taco/pyext2.py +++ b/skythought/skythought_evals/tasks/taco/pyext2.py @@ -42,9 +42,7 @@ "run_main", ] -import inspect # noqa: E402 -import sys # noqa: E402 -import types # noqa: E402 +import sys, inspect, types def __targspec(func, specs, attr="__orig_arg__"): @@ -379,9 +377,7 @@ def from_string(module_name_for_code_eval, docstring, s): class CaseObject(object): - """The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. - It can be called with multiple parameters, in which case it checks if its value equals any of the arguments. - """ + "The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. It can be called with multiple parameters, in which case it checks if its value equals any of the arguments." def __init__(self, value): self.value = value @@ -393,7 +389,7 @@ def __call__(self, *args): "res", not self.did_pass and any([self.value == rhs for rhs in args]) ): self.did_match = True - return res # noqa: F821 + return res def quit(self): "Forces all other calls to return False. Equilavent of a ``break`` statement." @@ -434,8 +430,7 @@ def switch(value): def tail_recurse(spec=None): """Remove tail recursion from a function. - :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. - If ``None``, tail recursion is always called unless the function returns a value. + :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. If ``None,`` tail recursion is always called unless the function returns a value. .. note:: @@ -489,7 +484,8 @@ def annotate(*args, **kwargs): :param kwargs: This is a mapping of argument names to annotations. Note that these are applied *after* the argument list, so any args set that way will be overriden by this mapping. 
If there is a key named `ret`, that will be the annotation for the function's return value. .. deprecated:: 0.5 - Use :func:`fannotate` instead.""" + Use :func:`fannotate` instead. + """ def _wrap(f): if not hasattr(f, "__annotations__"): @@ -530,8 +526,7 @@ def _wrap(f): def safe_unpack(seq, ln, fill=None): - """Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. - Empty values will be filled in with `fill`, while any extra values will be cut off. + """Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. Empty values will be filled in with `fill`, while any extra values will be cut off. :param seq: The sequence to unpack. From 9d64fd160d48c2c8b56dc71a294d190eba0b4e3e Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Mon, 27 Jan 2025 23:34:22 +0000 Subject: [PATCH 26/47] more comments Signed-off-by: SumanthRH --- skythought/skythought_evals/tasks/amc23/amc23.yaml | 5 +++-- skythought/skythought_evals/tasks/apps/apps.yaml | 5 +++-- .../skythought_evals/tasks/livecodebench/livecodebench.yaml | 5 +++-- skythought/skythought_evals/tasks/numina/numina.yaml | 3 ++- skythought/skythought_evals/tasks/taco/taco.yaml | 3 ++- 5 files changed, 13 insertions(+), 8 deletions(-) diff --git a/skythought/skythought_evals/tasks/amc23/amc23.yaml b/skythought/skythought_evals/tasks/amc23/amc23.yaml index a7ece4d..f0933ed 100644 --- a/skythought/skythought_evals/tasks/amc23/amc23.yaml +++ b/skythought/skythought_evals/tasks/amc23/amc23.yaml @@ -5,7 +5,8 @@ dataset_kwargs: dataset_split: train question_key: problem answer_key: answer -preprocess_config: - difficulty: null +# Optionally, you can filter the dataset by difficulty +# preprocess_config: +# difficulty: easy templating_parameters: template: "Return your final response within \\boxed{{}}. {problem}" diff --git a/skythought/skythought_evals/tasks/apps/apps.yaml b/skythought/skythought_evals/tasks/apps/apps.yaml index 455d027..04ca7d1 100644 --- a/skythought/skythought_evals/tasks/apps/apps.yaml +++ b/skythought/skythought_evals/tasks/apps/apps.yaml @@ -12,5 +12,6 @@ templating_parameters: without_fn_name_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. {prompt}" # Add starter code on top of the initial template with_starter_code_template: "{input}\n{starter_code}" -# preprocess_config; -# difficulty: easy # optional filter config \ No newline at end of file +# Optionally, you can filter the dataset by difficulty +# preprocess_config: +# difficulty: easy diff --git a/skythought/skythought_evals/tasks/livecodebench/livecodebench.yaml b/skythought/skythought_evals/tasks/livecodebench/livecodebench.yaml index a8347fd..ec060be 100644 --- a/skythought/skythought_evals/tasks/livecodebench/livecodebench.yaml +++ b/skythought/skythought_evals/tasks/livecodebench/livecodebench.yaml @@ -10,5 +10,6 @@ answer_key: null templating_parameters: stdin_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}" non_stdin_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. 
{prompt}" -preprocess_config: null -# difficulty: easy # use all by default \ No newline at end of file +# Optionally, you can filter the dataset by difficulty +# preprocess_config: +# difficulty: easy diff --git a/skythought/skythought_evals/tasks/numina/numina.yaml b/skythought/skythought_evals/tasks/numina/numina.yaml index b8cbdff..7f4c8f8 100644 --- a/skythought/skythought_evals/tasks/numina/numina.yaml +++ b/skythought/skythought_evals/tasks/numina/numina.yaml @@ -6,5 +6,6 @@ question_key: problem answer_key: solution templating_parameters: template: "Return your final response within \\boxed{{}}. {prompt}" +# Optionally, you can filter the dataset by difficulty # preprocess_config: -# difficulty: null \ No newline at end of file +# difficulty: easy diff --git a/skythought/skythought_evals/tasks/taco/taco.yaml b/skythought/skythought_evals/tasks/taco/taco.yaml index 85d317b..d411e72 100644 --- a/skythought/skythought_evals/tasks/taco/taco.yaml +++ b/skythought/skythought_evals/tasks/taco/taco.yaml @@ -14,6 +14,7 @@ templating_parameters: stdin_template: "{input}\nUse Standard Input format\nANSWER:\n" # call template is used when there is starter code or fn_name call_template: "{input}\nUse Call-Based format\nANSWER:\n" +# Optionally, you can filter the dataset by difficulty # preprocess_config: -# difficulty: null +# difficulty: easy From d55d2a80f8d9f469551c6b39ca7c3e021a1f535b Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 01:43:23 +0000 Subject: [PATCH 27/47] x Signed-off-by: SumanthRH --- .pre-commit-config.yaml | 3 +- skythought/skythought_evals/README.md | 40 +- .../skythought_evals/inference_and_check.py | 18 +- .../skythought_evals/tasks/apps/apps.yaml | 3 +- .../tasks/apps/apps_handler.py | 1 + .../livecodebench/livecodebench_handler.py | 2 +- .../skythought_evals/tasks/taco/pyext2.py | 481 +++++++----------- .../skythought_evals/tasks/taco/taco.yaml | 2 +- .../preprocessing.py | 7 +- 9 files changed, 219 insertions(+), 338 deletions(-) rename tests/{tools => skythought_evals}/preprocessing.py (89%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2f233aa..d2655bf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,7 @@ repos: hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] + # NOTE (sumanthrh): Many of the files excluded here are used for validating code generation, and linters do not recognize some of the logic in these files. skythought/train is excluded for now because it's a fork of Llamafactory exclude: (^skythought/train|skythought_evals/tasks/taco/pyext2\.py|skythought_evals/tasks/taco/taco_util\.py|skythought_evals/tasks/apps/apps_util\.py|skythought_evals/util/prompts\.py|skythought_evals/util/model_utils\.py)$ @@ -12,4 +13,4 @@ repos: rev: 24.10.0 hooks: - id: black - exclude: ^skythought/train + exclude: (^skythought/train|skythought_evals/tasks/taco/pyext2\.py)$ diff --git a/skythought/skythought_evals/README.md b/skythought/skythought_evals/README.md index 58c7e0e..26d2f15 100644 --- a/skythought/skythought_evals/README.md +++ b/skythought/skythought_evals/README.md @@ -27,17 +27,17 @@ The expected output is labeled_source_0_-1.json. We also provide instructions to Inference the results from QwQ on several datasets. In preview version, we use data from the following dataset. 
```shell -python inference_and_check.py --dataset APPS --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --inference +python inference_and_check.py --task apps --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --inference -python inference_and_check.py --dataset TACO --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source MEDIUM --filter-difficulty --result-dir $SKYT_HOME/data --inference +python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source MEDIUM --filter-difficulty --result-dir $SKYT_HOME/data --inference -python inference_and_check.py --dataset TACO --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --inference +python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --inference -python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty --result-dir $SKYT_HOME/data --inference +python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty math --result-dir $SKYT_HOME/data --inference -python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty --result-dir $SKYT_HOME/data --inference +python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty --result-dir $SKYT_HOME/data --inference -python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source olympiads --end 20000 --filter-difficulty --result-dir $SKYT_HOME/data --inference +python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source olympiads --end 20000 --filter-difficulty --result-dir $SKYT_HOME/data --inference ``` ### Step 2: Format the response @@ -48,7 +48,7 @@ python convert_format.py --input_dir $SKYT_HOME/data --keys keys.txt ### Step 3: Reject Sampling on the formatted data (Example Usage with previous script) ```shell -python inference_and_check.py --dataset APPS --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --check +python inference_and_check.py --task apps --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --check ``` Similar for other datasets. 
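For convenience, the three steps above can be chained per task with a short driver script. The following is a minimal sketch that only wraps the commands shown in this section (it assumes `$SKYT_HOME` is set and reuses the same model and task flags; it is not a script that ships with the repository):

```python
# Minimal sketch: run the inference -> convert -> check steps for one task.
# The commands mirror the README examples above; adjust model/task/flags as needed.
import os
import subprocess

SKYT_HOME = os.environ["SKYT_HOME"]  # assumed to point at the data directory root
MODEL = "Qwen/QwQ-32B-Preview"

steps = [
    f"python inference_and_check.py --task apps --model {MODEL} --tp 8 --max_tokens 16384 "
    f"--split test --source all --result-dir {SKYT_HOME}/data --inference",
    f"python convert_format.py --input_dir {SKYT_HOME}/data --keys keys.txt",
    f"python inference_and_check.py --task apps --model {MODEL} --tp 8 --max_tokens 16384 "
    f"--split test --source all --result-dir {SKYT_HOME}/data --check",
]

for cmd in steps:
    subprocess.run(cmd, shell=True, check=True)  # stop early if any step fails
```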
@@ -67,24 +67,24 @@ Currently we support distill and reject sampling from various self-hosted models #### Example Usage ```shell -python inference_and_check.py --dataset APPS --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data +python inference_and_check.py --task apps --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data -python inference_and_check.py --dataset TACO --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source MEDIUM --filter-difficulty --result-dir $SKYT_HOME/data +python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source MEDIUM --filter-difficulty --result-dir $SKYT_HOME/data -python inference_and_check.py --dataset TACO --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data +python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data -python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 4 --math_difficulty_upper_bound 9 +python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 4 --math_difficulty_upper_bound 9 -python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 1 --math_difficulty_upper_bound 9 +python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 1 --math_difficulty_upper_bound 9 -python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source olympiads --end 20000 --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 9 --math_difficulty_upper_bound 9 +python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source olympiads --end 20000 --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 9 --math_difficulty_upper_bound 9 ``` #### Best-of-N Inference and Check ```bash -python inference_and_check.py --dataset MATH500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --inference --temperatures 0.7 --n 64 -python inference_and_check.py --dataset MATH500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --check --temperatures 0.7 --n 8 +python inference_and_check.py --task math500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --inference --temperatures 0.7 --n 64 +python inference_and_check.py --task math500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --check --temperatures 0.7 --n 8 ``` ### Benchmark Evaluations @@ -95,12 +95,12 @@ We provide a wrapper script `eval.py` to conveniently run reasoning benchmarks. 
**NOTE**: For reproducing `Sky-T1-32B-Preview` results on `AIME` and `GPQADiamond` dataset, pass in temperatures as `0.7`. ```shell -python eval.py --model NovaSky-AI/Sky-T1-32B-Preview --evals=AIME,GPQADiamond --tp=8 --output_file=results.txt --temperatures 0.7 +python eval.py --model NovaSky-AI/Sky-T1-32B-Preview --evals=aime,gpqa_diamond --tp=8 --output_file=results.txt --temperatures 0.7 ``` #### Example Usage ```shell -python eval.py --model Qwen/QwQ-32B-Preview --evals=AIME,MATH500,GPQADiamond --tp=8 --output_file=results.txt +python eval.py --model Qwen/QwQ-32B-Preview --evals=aime,math500,gpqa_diamond --tp=8 --output_file=results.txt ``` Example result: `{"AIME": , "MATH500": , "GPQADiamond": }` @@ -111,14 +111,14 @@ The file `response_rewrite.py` provides a pipeline for filtering and rewriting r To use our preference optimization pipeline, first generate and score multiple responses using `inference_and_check.py`. For example: ```shell -python inference_and_check.py --inference --dataset MATH500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --temperatures 0.7 --n 8 -python inference_and_check.py --check --dataset MATH500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --temperatures 0.7 --n 8 +python inference_and_check.py --inference --task math500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --result-dir ./ --temperatures 0.7 --n 8 +python inference_and_check.py --check --task math500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --result-dir ./ --temperatures 0.7 --n 8 ``` Then, use `response_rewrite.py` to process the responses into preference pairs. By default, the shortest correct responses will be used as positive examples and the longest correct responses will be used as negative samples. The argument `--SILC` can be used to also include short incorrect responses as negative examples and long correct repsonses as positive samples. ```shell -python response_rewrite.py --SILC --rewrite-model meta-llama/Meta-Llama-3-8B-Instruct --target-model NovaSky-AI/Sky-T1-32B-Preview --dataset [PATH_TO_GENERATED_RESPONSES] --result-dir ./ --checkpoint --tp 8 +python response_rewrite.py --SILC --rewrite-model meta-llama/Meta-Llama-3-8B-Instruct --target-model NovaSky-AI/Sky-T1-32B-Preview --task [PATH_TO_GENERATED_RESPONSES] --result-dir ./ --checkpoint --tp 8 ``` The `--checkpoint` argument can optionally be used to save intermediate files of the processed data between steps, in case of failure. 
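The default pairing rule described above (shortest correct response as the positive example, longest correct response as the negative example) can be summarized in a few lines. The sketch below illustrates only that selection logic and is not the implementation in `response_rewrite.py`; the `(text, is_correct)` record shape is assumed for the example:

```python
# Illustrative sketch of the default pairing rule described above:
# shortest correct response -> positive ("chosen"), longest correct -> negative ("rejected").
# The (text, is_correct) record shape is assumed for this example.
from typing import List, Optional, Tuple


def build_preference_pair(responses: List[Tuple[str, bool]]) -> Optional[dict]:
    correct = [text for text, ok in responses if ok]
    if len(correct) < 2:
        return None  # not enough correct samples to form a pair
    chosen = min(correct, key=len)
    rejected = max(correct, key=len)
    if chosen == rejected:
        return None
    return {"chosen": chosen, "rejected": rejected}


# Example: three sampled responses for one problem, two of them correct.
print(
    build_preference_pair(
        [
            ("The answer is 12.", True),
            ("Let us carefully re-derive every step before concluding 12.", True),
            ("The answer is 15.", False),
        ]
    )
)
```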
diff --git a/skythought/skythought_evals/inference_and_check.py b/skythought/skythought_evals/inference_and_check.py index df60734..7c3f10e 100644 --- a/skythought/skythought_evals/inference_and_check.py +++ b/skythought/skythought_evals/inference_and_check.py @@ -7,9 +7,9 @@ import numpy as np from openai import OpenAI +from skythought_evals.tasks import TASK_HANDLER_MAP, NUMINATaskHandler, TaskHandler +from skythought_evals.tasks.task_util import get_tasks from skythought_evals.util.model_utils import MODEL_TO_NAME, SYSTEM_PROMPT -from tasks import TASK_HANDLER_MAP, NUMINATaskHandler, TaskHandler -from tasks.task_util import get_tasks from tqdm import tqdm from vllm import LLM, SamplingParams @@ -446,7 +446,7 @@ def main(): "--split", type=str, default=None, - help="Split to use for apps (e.g., train, test).", + help="Split to use for the dataset (e.g., train, test).", ) parser.add_argument("--source", type=str, help="Source for the dataset.") parser.add_argument("--start", type=int, default=0, help="Start index.") @@ -506,7 +506,8 @@ def main(): # Currently kept here for consistency with old code args.split = args.split if args.split else handler.task_config.dataset_split args.source = args.source if args.source else handler.task_config.dataset_source - + if not args.filter_difficulty and handler.task_config.preprocess_config: + args.filter_difficulty = handler.task_config.preprocess_config.difficulty # create result dir if not exists if args.result_dir and not os.path.exists(args.result_dir): os.makedirs(args.result_dir) @@ -516,12 +517,12 @@ def main(): ): result_file = os.path.join( args.result_dir, - f"{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json", + f"{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.filter_difficulty}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json", ) else: result_file = os.path.join( args.result_dir, - f"{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}.json", + f"{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.filter_difficulty}_{args.start}_{args.end}.json", ) if args.check: @@ -531,11 +532,12 @@ def main(): or args.math_difficulty_upper_bound is not None ): converted_file = ( - f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}" + f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.filter_difficulty}_{args.start}_{args.end}" + f"_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json" ) else: - converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.start}_{args.end}.json" + converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.filter_difficulty}" + f"_{args.start}_{args.end}.json" if os.path.exists(converted_file): result_file = converted_file perform_check(handler, temperatures, result_file, args) diff --git a/skythought/skythought_evals/tasks/apps/apps.yaml b/skythought/skythought_evals/tasks/apps/apps.yaml index 04ca7d1..eeb1569 100644 --- a/skythought/skythought_evals/tasks/apps/apps.yaml +++ b/skythought/skythought_evals/tasks/apps/apps.yaml @@ -1,8 +1,9 @@ handler: apps dataset_path: codeparrot/apps +dataset_source: all 
dataset_kwargs: trust_remote_code: true -dataset_split: train +dataset_split: test question_key: question answer_key: null # preprocess_config: diff --git a/skythought/skythought_evals/tasks/apps/apps_handler.py b/skythought/skythought_evals/tasks/apps/apps_handler.py index d9b5b20..63c6b39 100644 --- a/skythought/skythought_evals/tasks/apps/apps_handler.py +++ b/skythought/skythought_evals/tasks/apps/apps_handler.py @@ -21,6 +21,7 @@ def generate_prompt(self, test_case, prompt, starter_code=None): _input = self.task_config.templating_parameters[ "without_fn_name_template" ].format(prompt=prompt) + if starter_code is not None: _input = self.task_config.templating_parameters[ "with_starter_code_template" diff --git a/skythought/skythought_evals/tasks/livecodebench/livecodebench_handler.py b/skythought/skythought_evals/tasks/livecodebench/livecodebench_handler.py index 5e50ff3..8f99ce2 100644 --- a/skythought/skythought_evals/tasks/livecodebench/livecodebench_handler.py +++ b/skythought/skythought_evals/tasks/livecodebench/livecodebench_handler.py @@ -104,7 +104,7 @@ def load_and_filter_dataset( # Filter by CLI or config if filter_difficulty or self.task_config.preprocess_config.difficulty: difficulty = ( - source + filter_difficulty if filter_difficulty else self.task_config.preprocess_config.difficulty ) diff --git a/skythought/skythought_evals/tasks/taco/pyext2.py b/skythought/skythought_evals/tasks/taco/pyext2.py index 42ea666..ff1593d 100644 --- a/skythought/skythought_evals/tasks/taco/pyext2.py +++ b/skythought/skythought_evals/tasks/taco/pyext2.py @@ -1,4 +1,4 @@ -""" +''' Copyright (C) 2014 Ryan Gonzalez @@ -18,58 +18,37 @@ COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -""" +''' g_backup = globals().copy() -__version__ = "0.7" - -__all__ = [ - "overload", - "RuntimeModule", - "switch", - "tail_recurse", - "copyfunc", - "set_docstring", - "annotate", - "safe_unpack", - "modify_function", - "assign", - "fannotate", - "compare_and_swap", - "is_main", - "call_if_main", - "run_main", -] +__version__ = '0.7' -import sys, inspect, types +__all__ = ['overload', 'RuntimeModule', 'switch', 'tail_recurse', 'copyfunc', 'set_docstring', 'annotate', 'safe_unpack', 'modify_function', 'assign', 'fannotate', 'compare_and_swap', 'is_main', 'call_if_main', 'run_main'] +import sys, inspect, types -def __targspec(func, specs, attr="__orig_arg__"): - if hasattr(func, "__is_overload__") and func.__is_overload__: +def __targspec(func, specs, attr='__orig_arg__'): + if hasattr(func, '__is_overload__') and func.__is_overload__: return getattr(func, attr) return specs(func) - def set_docstring(doc): - """A simple decorator to set docstrings. + '''A simple decorator to set docstrings. - :param doc: The docstring to tie to the function. + :param doc: The docstring to tie to the function. - Example:: - - @set_docstring('This is a docstring') - def myfunc(x): - pass""" + Example:: + @set_docstring('This is a docstring') + def myfunc(x): + pass''' def _wrap(f): f.__doc__ = doc return f - return _wrap - -__modify_function_doc = """ +__modify_function_doc = ''' Creates a copy of a function, changing its attributes. :param globals: Will be added to the function's globals. @@ -83,100 +62,63 @@ def _wrap(f): :param closure: The new function closure. Set to ``None`` to use the function's original closure. .. 
warning:: This function can be potentially dangerous. -""" - +''' def copyfunc(f): - """Copies a funcion. + '''Copies a funcion. - :param f: The function to copy. + :param f: The function to copy. - :return: The copied function. - - .. deprecated:: 0.4 - Use :func:`modify_function` instead. - """ - return modify_function(f) + :return: The copied function. + .. deprecated:: 0.4 + Use :func:`modify_function` instead. + ''' + return modify_function(f) if sys.version_info.major == 3: - @set_docstring(__modify_function_doc) - def modify_function( - f, globals={}, name=None, code=None, defaults=None, closure=None - ): - if code is None: - code = f.__code__ - if name is None: - name = f.__name__ - if defaults is None: - defaults = f.__defaults__ - if closure is None: - closure = f.__closure__ - newf = types.FunctionType( - code, - dict(f.__globals__, **globals), - name=name, - argdefs=defaults, - closure=closure, - ) + def modify_function(f, globals={}, name=None, code=None, defaults=None, + closure=None): + if code is None: code = f.__code__ + if name is None: name = f.__name__ + if defaults is None: defaults = f.__defaults__ + if closure is None: closure = f.__closure__ + newf = types.FunctionType(code, dict(f.__globals__, **globals), name=name, + argdefs=defaults, closure=closure) newf.__dict__.update(f.__dict__) return newf - def argspec(f): return inspect.getfullargspec(f) - ofullargspec = inspect.getfullargspec - def _fullargspec(func): return __targspec(func, ofullargspec) - inspect.getfullargspec = _fullargspec - - def _exec(m, g): - exec(m, g) - + def _exec(m,g): exec(m,g) else: - @set_docstring(__modify_function_doc) - def modify_function( - f, globals={}, name=None, code=None, defaults=None, closure=None - ): - if code is None: - code = f.func_code - if name is None: - name = f.__name__ - if defaults is None: - defaults = f.func_defaults - if closure is None: - closure = f.func_closure - newf = types.FunctionType( - code, - dict(f.func_globals, **globals), - name=name, - argdefs=defaults, - closure=closure, - ) + def modify_function(f, globals={}, name=None, code=None, defaults=None, + closure=None): + if code is None: code = f.func_code + if name is None: name = f.__name__ + if defaults is None: defaults = f.func_defaults + if closure is None: closure = f.func_closure + newf = types.FunctionType(code, dict(f.func_globals, **globals), name=name, + argdefs=defaults, closure=closure) newf.__dict__.update(f.__dict__) return newf - def argspec(f): return inspect.getargspec(f) - - eval(compile("def _exec(m,g): exec m in g", "", "exec")) - + eval(compile('def _exec(m,g): exec m in g', '', 'exec')) def _gettypes(args): return tuple(map(type, args)) - oargspec = inspect.getargspec - def _argspec(func): return __targspec(func, oargspec) - inspect.getargspec = _argspec try: @@ -186,60 +128,50 @@ def _argspec(func): else: # Replace IPython's argspec oipyargspec = IPython.core.oinspect.getargspec - def _ipyargspec(func): - return __targspec(func, oipyargspec, "__orig_arg_ipy__") - + return __targspec(func, oipyargspec, '__orig_arg_ipy__') IPython.core.oinspect.getargspec = _ipyargspec - class overload(object): - """Simple function overloading in Python.""" - + '''Simple function overloading in Python.''' _items = {} _types = {} - @classmethod def argc(self, argc=None): - """Overloads a function based on the specified argument count. + '''Overloads a function based on the specified argument count. - :param argc: The argument count. Defaults to ``None``. 
If ``None`` is given, automatically compute the argument count from the given function. + :param argc: The argument count. Defaults to ``None``. If ``None`` is given, automatically compute the argument count from the given function. - .. note:: + .. note:: - Keyword argument counts are NOT checked! In addition, when the argument count is automatically calculated, the keyword argument count is also ignored! + Keyword argument counts are NOT checked! In addition, when the argument count is automatically calculated, the keyword argument count is also ignored! - Example:: + Example:: - @overload.argc() - def func(a): - print 'Function 1 called' + @overload.argc() + def func(a): + print 'Function 1 called' - @overload.argc() - def func(a, b): - print 'Function 2 called' + @overload.argc() + def func(a, b): + print 'Function 2 called' - func(1) # Calls first function - func(1, 2) # Calls second function - func() # Raises error - """ + func(1) # Calls first function + func(1, 2) # Calls second function + func() # Raises error + ''' # Python 2 UnboundLocalError fix - argc = {"argc": argc} - + argc = {'argc': argc} def _wrap(f): def _newf(*args, **kwargs): if len(args) not in self._items[f.__name__]: - raise TypeError( - "No overload of function '%s' that takes %d args" - % (f.__name__, len(args)) - ) + raise TypeError("No overload of function '%s' that takes %d args" % (f.__name__, len(args))) return self._items[f.__name__][len(args)](*args, **kwargs) - if f.__name__ not in self._items: self._items[f.__name__] = {} - if argc["argc"] is None: - argc["argc"] = len(argspec(f).args) - self._items[f.__name__][argc["argc"]] = f + if argc['argc'] is None: + argc['argc'] = len(argspec(f).args) + self._items[f.__name__][argc['argc']] = f _newf.__name__ = f.__name__ _newf.__doc__ = f.__doc__ _newf.__is_overload__ = True @@ -247,68 +179,54 @@ def _newf(*args, **kwargs): if IPython: _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f) return _newf - return _wrap - @classmethod def args(self, *argtypes, **kw): - """Overload a function based on the specified argument types. + '''Overload a function based on the specified argument types. - :param argtypes: The argument types. If None is given, get the argument types from the function annotations(Python 3 only) - :param kw: Can only contain 1 argument, `is_cls`. If True, the function is assumed to be part of a class. + :param argtypes: The argument types. If None is given, get the argument types from the function annotations(Python 3 only) + :param kw: Can only contain 1 argument, `is_cls`. If True, the function is assumed to be part of a class. - Example:: + Example:: - @overload.args(str) - def func(s): - print 'Got string' + @overload.args(str) + def func(s): + print 'Got string' - @overload.args(int, str) - def func(i, s): - print 'Got int and string' + @overload.args(int, str) + def func(i, s): + print 'Got int and string' - @overload.args() - def func(i:int): # A function annotation example - print 'Got int' + @overload.args() + def func(i:int): # A function annotation example + print 'Got int' - func('s') - func(1) - func(1, 's') - func(True) # Raises error - """ + func('s') + func(1) + func(1, 's') + func(True) # Raises error + ''' # Python 2 UnboundLocalError fix...again! 
- argtypes = {"args": tuple(argtypes)} - + argtypes = {'args': tuple(argtypes)} def _wrap(f): def _newf(*args): if len(kw) == 0: cargs = args - elif len(kw) == 1 and "is_cls" in kw and kw["is_cls"]: + elif len(kw) == 1 and 'is_cls' in kw and kw['is_cls']: cargs = args[1:] else: - raise ValueError("Invalid keyword args specified") + raise ValueError('Invalid keyword args specified') if _gettypes(cargs) not in self._types[f.__name__]: - raise TypeError( - "No overload of function '%s' that takes '%s' types and %d arg(s)" - % (f.__name__, _gettypes(cargs), len(cargs)) - ) + raise TypeError("No overload of function '%s' that takes '%s' types and %d arg(s)" % (f.__name__, _gettypes(cargs), len(cargs))) return self._types[f.__name__][_gettypes(cargs)](*args) - if f.__name__ not in self._types: self._types[f.__name__] = {} - if len(argtypes["args"]) == 1 and argtypes["args"][0] is None: + if len(argtypes['args']) == 1 and argtypes['args'][0] is None: aspec = argspec(f) - argtypes["args"] = tuple( - map( - lambda x: x[1], - sorted( - aspec.annotations.items(), - key=lambda x: aspec.args.index(x[0]), - ), - ) - ) - self._types[f.__name__][argtypes["args"]] = f + argtypes['args'] = tuple(map(lambda x: x[1], sorted( + aspec.annotations.items(), key=lambda x: aspec.args.index(x[0])))) + self._types[f.__name__][argtypes['args']] = f _newf.__name__ = f.__name__ _newf.__doc__ = f.__doc__ _newf.__is_overload__ = True @@ -316,146 +234,119 @@ def _newf(*args): if IPython: _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f) return _newf - return _wrap - class _RuntimeModule(object): - "Create a module object at runtime and insert it into sys.path. If called, same as :py:func:`from_objects`." - + 'Create a module object at runtime and insert it into sys.path. If called, same as :py:func:`from_objects`.' def __call__(self, *args, **kwargs): return self.from_objects(*args, **kwargs) - @staticmethod @overload.argc(1) def from_objects(module_name_for_code_eval, **d): - return _RuntimeModule.from_objects(module_name_for_code_eval, "", **d) - + return _RuntimeModule.from_objects(module_name_for_code_eval, '', **d) @staticmethod @overload.argc(2) def from_objects(module_name_for_code_eval, docstring, **d): - """Create a module at runtime from `d`. + '''Create a module at runtime from `d`. - :param name: The module name. + :param name: The module name. - :param docstring: Optional. The module's docstring. + :param docstring: Optional. The module's docstring. - :param \*\*d: All the keyword args, mapped from name->value. + :param \*\*d: All the keyword args, mapped from name->value. - Example: ``RuntimeModule.from_objects('name', 'doc', a=1, b=2)``""" + Example: ``RuntimeModule.from_objects('name', 'doc', a=1, b=2)``''' module = types.ModuleType(module_name_for_code_eval, docstring) module.__dict__.update(d) - module.__file__ = "" + module.__file__ = '' sys.modules[module_name_for_code_eval] = module return module - @staticmethod @overload.argc(2) def from_string(module_name_for_code_eval, s): - return _RuntimeModule.from_string(module_name_for_code_eval, "", s) - + return _RuntimeModule.from_string(module_name_for_code_eval, '', s) @staticmethod @overload.argc(3) def from_string(module_name_for_code_eval, docstring, s): - """Create a module at runtime from `s``. + '''Create a module at runtime from `s``. - :param name: The module name. + :param name: The module name. - :param docstring: Optional. The module docstring. + :param docstring: Optional. The module docstring. 
- :param s: A string containing the module definition.""" + :param s: A string containing the module definition.''' g = {} _exec(s, g) - return _RuntimeModule.from_objects( - module_name_for_code_eval, - docstring, - **dict(filter(lambda x: x[0] not in g_backup, g.items())) - ) - + return _RuntimeModule.from_objects(module_name_for_code_eval, docstring, **dict(filter(lambda x: x[0] not in g_backup, g.items()))) RuntimeModule = _RuntimeModule() - class CaseObject(object): - "The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. It can be called with multiple parameters, in which case it checks if its value equals any of the arguments." - + 'The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. It can be called with multiple parameters, in which case it checks if its value equals any of the arguments.' def __init__(self, value): self.value = value self.did_match = False self.did_pass = False - def __call__(self, *args): - if assign( - "res", not self.did_pass and any([self.value == rhs for rhs in args]) - ): + if assign('res', not self.did_pass and any([self.value == rhs for rhs in args])): self.did_match = True return res - def quit(self): - "Forces all other calls to return False. Equilavent of a ``break`` statement." + 'Forces all other calls to return False. Equilavent of a ``break`` statement.' self.did_pass = True - def default(self): "Executed if quit wasn't called." return not self.did_match and not self.did_pass - def __iter__(self): yield self - def __enter__(self): return self - def __exit__(self, *args): pass - def switch(value): - """A Python switch statement implementation that is used with a ``with`` statement. + '''A Python switch statement implementation that is used with a ``with`` statement. - :param value: The value to "switch". + :param value: The value to "switch". - ``with`` statement example:: + ``with`` statement example:: - with switch('x'): - if case(1): print 'Huh?' - if case('x'): print 'It works!!!' + with switch('x'): + if case(1): print 'Huh?' + if case('x'): print 'It works!!!' - .. warning:: If you modify a variable named "case" in the same scope that you use the ``with`` statement version, you will get an UnboundLocalError. The soluction is to use ``with switch('x') as case:`` instead of ``with switch('x'):``. - """ + .. warning:: If you modify a variable named "case" in the same scope that you use the ``with`` statement version, you will get an UnboundLocalError. The soluction is to use ``with switch('x') as case:`` instead of ``with switch('x'):``.''' res = CaseObject(value) - inspect.stack()[1][0].f_globals["case"] = res + inspect.stack()[1][0].f_globals['case'] = res return res - def tail_recurse(spec=None): - """Remove tail recursion from a function. + '''Remove tail recursion from a function. - :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. If ``None,`` tail recursion is always called unless the function returns a value. + :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. If ``None,`` tail recursion is always called unless the function returns a value. - .. note:: + .. note:: - This function has a slight overhead that is noticable when using timeit. Only use it if the function has a possibility of going over the recursion limit. 
+ This function has a slight overhead that is noticable when using timeit. Only use it if the function has a possibility of going over the recursion limit. - .. warning:: + .. warning:: - This function will BREAK any code that either uses any recursion other than tail recursion or calls itself multiple times. For example, ``def x(): return x()+1`` will fail. + This function will BREAK any code that either uses any recursion other than tail recursion or calls itself multiple times. For example, ``def x(): return x()+1`` will fail. - Example:: + Example:: - @tail_recurse() - def add(a, b): - if a == 0: return b - return add(a-1, b+1) - - add(10000000, 1) # Doesn't max the recursion limit. - """ + @tail_recurse() + def add(a, b): + if a == 0: return b + return add(a-1, b+1) + add(10000000, 1) # Doesn't max the recursion limit. + ''' def _wrap(f): class TailRecursion(Exception): def __init__(self, args, kwargs): self.args = args self.kwargs = kwargs - def _newf(*args, **kwargs): if inspect.stack()[1][3] == f.__name__: if (spec and spec(args)) or not spec: @@ -469,138 +360,122 @@ def _newf(*args, **kwargs): continue else: return res - _newf.__doc__ = f.__doc__ return _newf - return _wrap - def annotate(*args, **kwargs): - """Set function annotations using decorators. + '''Set function annotations using decorators. - :param args: This is a list of annotations for the function, in the order of the function's parameters. For example, ``annotate('Annotation 1', 'Annotation 2')`` will set the annotations of parameter 1 of the function to ``Annotation 1``. + :param args: This is a list of annotations for the function, in the order of the function's parameters. For example, ``annotate('Annotation 1', 'Annotation 2')`` will set the annotations of parameter 1 of the function to ``Annotation 1``. - :param kwargs: This is a mapping of argument names to annotations. Note that these are applied *after* the argument list, so any args set that way will be overriden by this mapping. If there is a key named `ret`, that will be the annotation for the function's return value. - - .. deprecated:: 0.5 - Use :func:`fannotate` instead. - """ + :param kwargs: This is a mapping of argument names to annotations. Note that these are applied *after* the argument list, so any args set that way will be overriden by this mapping. If there is a key named `ret`, that will be the annotation for the function's return value. + .. deprecated:: 0.5 + Use :func:`fannotate` instead. +''' def _wrap(f): - if not hasattr(f, "__annotations__"): + if not hasattr(f, '__annotations__'): f.__annotations__ = {} - if "ret" in kwargs: - f.__annotations__["return"] = kwargs.pop("ret") + if 'ret' in kwargs: + f.__annotations__['return'] = kwargs.pop('ret') f.__annotations__.update(dict(zip(argspec(f).args, args))) f.__annotations__.update(kwargs) return f - return _wrap - def fannotate(*args, **kwargs): - """Set function annotations using decorators. - - :param \*args: The first positional argument is used for the function's return value; all others are discarded. + '''Set function annotations using decorators. - :param \**kwargs: This is a mapping of argument names to annotations. + :param \*args: The first positional argument is used for the function's return value; all others are discarded. - Example:: + :param \**kwargs: This is a mapping of argument names to annotations. 
- @fannotate('This for the return value', a='Parameter a', b='Parameter b') - def x(a, b): - pass + Example:: - """ + @fannotate('This for the return value', a='Parameter a', b='Parameter b') + def x(a, b): + pass + ''' def _wrap(f): - if not hasattr(f, "__annotations__"): + if not hasattr(f, '__annotations__'): f.__annotations__ = {} if len(args) >= 1: - f.__annotations__["return"] = args[0] + f.__annotations__['return'] = args[0] f.__annotations__.update(kwargs) return f - return _wrap - def safe_unpack(seq, ln, fill=None): - """Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. Empty values will be filled in with `fill`, while any extra values will be cut off. + '''Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. Empty values will be filled in with `fill`, while any extra values will be cut off. - :param seq: The sequence to unpack. + :param seq: The sequence to unpack. - :param ln: The expected length of the sequence. + :param ln: The expected length of the sequence. - :param fill: The value to substitute if the sequence is too small. Defaults to ``None``. + :param fill: The value to substitute if the sequence is too small. Defaults to ``None``. - Example:: + Example:: - s = 'a:b' - a, b = safe_unpack(s.split(':'), 2) - # a = 'a' - # b = 'b' - s = 'a' - a, b = safe_unpack(s.split(':'), 2) - # a = 'a' - # b = None""" + s = 'a:b' + a, b = safe_unpack(s.split(':'), 2) + # a = 'a' + # b = 'b' + s = 'a' + a, b = safe_unpack(s.split(':'), 2) + # a = 'a' + # b = None''' if len(seq) > ln: return seq[:ln] elif len(seq) < ln: - return seq + type(seq)([fill] * (ln - len(seq))) + return seq + type(seq)([fill]*(ln-len(seq))) else: return seq - def assign(varname, value): - """Assign `value` to `varname` and return it. If `varname` is an attribute and the instance name it belongs to is not defined, a NameError is raised. - This can be used to emulate assignment as an expression. For example, this:: + '''Assign `value` to `varname` and return it. If `varname` is an attribute and the instance name it belongs to is not defined, a NameError is raised. + This can be used to emulate assignment as an expression. For example, this:: - if assign('x', 7): ... + if assign('x', 7): ... - is equilavent to this C code:: + is equilavent to this C code:: - if (x = 7) ... + if (x = 7) ... - .. warning:: + .. warning:: - When assigning an attribute, the instance it belongs to MUST be declared as global prior to the assignment. Otherwise, the assignment will not work. - """ + When assigning an attribute, the instance it belongs to MUST be declared as global prior to the assignment. Otherwise, the assignment will not work. + ''' fd = inspect.stack()[1][0].f_globals - if "." not in varname: + if '.' not in varname: fd[varname] = value else: - vsplit = list(map(str.strip, varname.split("."))) + vsplit = list(map(str.strip, varname.split('.'))) if vsplit[0] not in fd: - raise NameError("Unknown object: %s" % vsplit[0]) + raise NameError('Unknown object: %s'%vsplit[0]) base = fd[vsplit[0]] for x in vsplit[1:-1]: base = getattr(base, x) setattr(base, vsplit[-1], value) return value - def is_main(frame=1): "Return if the caller is main. Equilavent to ``__name__ == '__main__'``." 
- return inspect.stack()[frame][0].f_globals["__name__"] == "__main__" - + return inspect.stack()[frame][0].f_globals['__name__'] == '__main__' def _call_if_main(frame, f, args): - if is_main(frame): - return f(*args) - + if is_main(frame): return f(*args) -def call_if_main(f, *args): +def call_if_main(f,*args): "Call the `f` with `args` if the caller's module is main." - return _call_if_main(3, f, args) + return _call_if_main(3,f,args) - -def run_main(f, *args): +def run_main(f,*args): "Call `f` with the `args` and terminate the program with its return code if the caller's module is main." - sys.exit(_call_if_main(3, f, args)) - + sys.exit(_call_if_main(3,f,args)) def compare_and_swap(var, compare, new): "If `var` is equal to `compare`, set it to `new`." - if assign("v", inspect.stack()[1][0].f_globals)[var] == compare: - v[var] = new + if assign('v', inspect.stack()[1][0].f_globals)[var] == compare: + v[var] = new \ No newline at end of file diff --git a/skythought/skythought_evals/tasks/taco/taco.yaml b/skythought/skythought_evals/tasks/taco/taco.yaml index d411e72..961d0f4 100644 --- a/skythought/skythought_evals/tasks/taco/taco.yaml +++ b/skythought/skythought_evals/tasks/taco/taco.yaml @@ -1,6 +1,6 @@ handler: taco dataset_path: "BAAI/TACO" -dataset_source: ALL +dataset_source: MEDIUM dataset_split: train dataset_kwargs: trust_remote_code: true diff --git a/tests/tools/preprocessing.py b/tests/skythought_evals/preprocessing.py similarity index 89% rename from tests/tools/preprocessing.py rename to tests/skythought_evals/preprocessing.py index 24d64fc..5662aaa 100644 --- a/tests/tools/preprocessing.py +++ b/tests/skythought_evals/preprocessing.py @@ -1,10 +1,11 @@ import pytest -from skyevals.tasks import MMLUTaskHandler, TaskConfig +from skythought_evals.tasks import MMLUTaskHandler, TaskConfig + +SYSTEM_PROMPT = "Please answer the following question:" inputs = [ ( { - "is_stdin": False, "question": "What is the capital of France?", "choices": ["Paris", "London", "Berlin", "Madrid"], "answer": "0", @@ -21,7 +22,7 @@ ), MMLUTaskHandler, [ - {"role": "system", "content": "Please answer the following question:"}, + {"role": "system", "content": SYSTEM_PROMPT}, { "role": "user", "content": "Return your final response within \\boxed{}. What is the capital of France?\nAnswer Choices: (A) Paris (B) London (C) Berlin (D) Madrid", # noqa: E501 From 840006f808c0ea3923120a3ec079e3f61d7e67e6 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 02:24:31 +0000 Subject: [PATCH 28/47] test workflows Signed-off-by: SumanthRH --- .github/workflows.yaml | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 .github/workflows.yaml diff --git a/.github/workflows.yaml b/.github/workflows.yaml new file mode 100644 index 0000000..33ba8b7 --- /dev/null +++ b/.github/workflows.yaml @@ -0,0 +1,34 @@ +name: Skythought evals + +on: [push] + +# Cancel runs for previous commits on the same branch +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + # This is the version of the action for setting up Python, not the Python version. 
+ uses: actions/setup-python@v5 + with: + # Semantic version range syntax or exact version of a Python version + python-version: '3.10' + cache: 'pip' + # You can test your matrix by printing the current Python version + - name: Display Python version + run: python -c "import sys; print(sys.version)" + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel pre-commit + - name: Install skythought_evals + run: python -m pip install -e . + - name: Run pre-commit hooks + run: pre-commit run --all-files + - name: Run tests + run: python -m pytest tests/ From 4cdeab08ee5cd5370fac139afa07db080c8ac0cf Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 02:26:08 +0000 Subject: [PATCH 29/47] x Signed-off-by: SumanthRH --- .github/{workflows.yaml => workflows/cpu_ci.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename .github/{workflows.yaml => workflows/cpu_ci.yaml} (100%) diff --git a/.github/workflows.yaml b/.github/workflows/cpu_ci.yaml similarity index 100% rename from .github/workflows.yaml rename to .github/workflows/cpu_ci.yaml From 8d564a33fd3ac89fefa655ab21ff09beb1e5536d Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 02:27:08 +0000 Subject: [PATCH 30/47] x Signed-off-by: SumanthRH --- .github/workflows/cpu_ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cpu_ci.yaml b/.github/workflows/cpu_ci.yaml index 33ba8b7..5e259fb 100644 --- a/.github/workflows/cpu_ci.yaml +++ b/.github/workflows/cpu_ci.yaml @@ -25,7 +25,7 @@ jobs: - name: Display Python version run: python -c "import sys; print(sys.version)" - name: Install dependencies - run: python -m pip install --upgrade pip setuptools wheel pre-commit + run: python -m pip install --upgrade pip setuptools wheel pre-commit pytest - name: Install skythought_evals run: python -m pip install -e . - name: Run pre-commit hooks From fc1087e50e60cb9f3846e1c3b053c3b04a6e3f02 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 02:35:00 +0000 Subject: [PATCH 31/47] x Signed-off-by: SumanthRH --- .github/workflows/cpu_ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cpu_ci.yaml b/.github/workflows/cpu_ci.yaml index 5e259fb..1cd844d 100644 --- a/.github/workflows/cpu_ci.yaml +++ b/.github/workflows/cpu_ci.yaml @@ -29,6 +29,6 @@ jobs: - name: Install skythought_evals run: python -m pip install -e . 
- name: Run pre-commit hooks - run: pre-commit run --all-files + run: pre-commit run --all-files --config .pre-commit-config.yaml - name: Run tests run: python -m pytest tests/ From 6e2e9790b3ca1b632cb36ad6a5afee2ec938c6d8 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 02:40:53 +0000 Subject: [PATCH 32/47] it's time to fight the CI Signed-off-by: SumanthRH --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d2655bf..5bb08eb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,4 +13,4 @@ repos: rev: 24.10.0 hooks: - id: black - exclude: (^skythought/train|skythought_evals/tasks/taco/pyext2\.py)$ + exclude: (^skythought/train/.*|skythought_evals/tasks/taco/pyext2\.py)$ From 744911729c060e51a2c5c1ef57d1cac6294cd777 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 02:49:29 +0000 Subject: [PATCH 33/47] I might have won the fight: Signed-off-by: SumanthRH --- .github/workflows/cpu_ci.yaml | 3 -- tests/__init__.py | 0 tests/skythought_evals/__init__.py | 0 tests/skythought_evals/preprocessing.py | 51 ------------------------- 4 files changed, 54 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/skythought_evals/__init__.py delete mode 100644 tests/skythought_evals/preprocessing.py diff --git a/.github/workflows/cpu_ci.yaml b/.github/workflows/cpu_ci.yaml index 1cd844d..38718d6 100644 --- a/.github/workflows/cpu_ci.yaml +++ b/.github/workflows/cpu_ci.yaml @@ -21,9 +21,6 @@ jobs: # Semantic version range syntax or exact version of a Python version python-version: '3.10' cache: 'pip' - # You can test your matrix by printing the current Python version - - name: Display Python version - run: python -c "import sys; print(sys.version)" - name: Install dependencies run: python -m pip install --upgrade pip setuptools wheel pre-commit pytest - name: Install skythought_evals diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/skythought_evals/__init__.py b/tests/skythought_evals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/skythought_evals/preprocessing.py b/tests/skythought_evals/preprocessing.py deleted file mode 100644 index 5662aaa..0000000 --- a/tests/skythought_evals/preprocessing.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest -from skythought_evals.tasks import MMLUTaskHandler, TaskConfig - -SYSTEM_PROMPT = "Please answer the following question:" - -inputs = [ - ( - { - "question": "What is the capital of France?", - "choices": ["Paris", "London", "Berlin", "Madrid"], - "answer": "0", - }, - TaskConfig( - handler="dummy", - dataset_path="dummy", - dataset_split="dummy", - question_key="question", - answer_key="answer", - templating_parameters={ - "template": "Return your final response within \\boxed{{}}. {prompt}" - }, - ), - MMLUTaskHandler, - [ - {"role": "system", "content": SYSTEM_PROMPT}, - { - "role": "user", - "content": "Return your final response within \\boxed{}. 
What is the capital of France?\nAnswer Choices: (A) Paris (B) London (C) Berlin (D) Madrid", # noqa: E501 - }, - ], - ), -] - - -@pytest.mark.parametrize("row,config,handler_cls,expected_conversation", inputs) -def test_make_conversations(row, config, handler_cls, expected_conversation): - - # Expected system prompt - system_prompt = "Please answer the following question:" - - # Initialize the handler - handler = handler_cls(config) - - # Expected conversation format - # expected input - # Call make_conversations - conversations = handler.make_conversations([row], system_prompt) - # Assert the conversation is as expected - assert conversations == [ - expected_conversation - ], f"Expected conversation {expected_conversation} but got {conversations}." From 04ead2a539883fcd6a1642dadcfac3da31d28767 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 02:57:40 +0000 Subject: [PATCH 34/47] CI please Signed-off-by: SumanthRH --- .gitignore | 1 - tests/skythought_evals/test_preprocessing.py | 51 ++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 tests/skythought_evals/test_preprocessing.py diff --git a/.gitignore b/.gitignore index 85b5c7c..3878c35 100644 --- a/.gitignore +++ b/.gitignore @@ -167,6 +167,5 @@ cython_debug/ .json token_usage/ -test_* run_all.sh diff --git a/tests/skythought_evals/test_preprocessing.py b/tests/skythought_evals/test_preprocessing.py new file mode 100644 index 0000000..5662aaa --- /dev/null +++ b/tests/skythought_evals/test_preprocessing.py @@ -0,0 +1,51 @@ +import pytest +from skythought_evals.tasks import MMLUTaskHandler, TaskConfig + +SYSTEM_PROMPT = "Please answer the following question:" + +inputs = [ + ( + { + "question": "What is the capital of France?", + "choices": ["Paris", "London", "Berlin", "Madrid"], + "answer": "0", + }, + TaskConfig( + handler="dummy", + dataset_path="dummy", + dataset_split="dummy", + question_key="question", + answer_key="answer", + templating_parameters={ + "template": "Return your final response within \\boxed{{}}. {prompt}" + }, + ), + MMLUTaskHandler, + [ + {"role": "system", "content": SYSTEM_PROMPT}, + { + "role": "user", + "content": "Return your final response within \\boxed{}. What is the capital of France?\nAnswer Choices: (A) Paris (B) London (C) Berlin (D) Madrid", # noqa: E501 + }, + ], + ), +] + + +@pytest.mark.parametrize("row,config,handler_cls,expected_conversation", inputs) +def test_make_conversations(row, config, handler_cls, expected_conversation): + + # Expected system prompt + system_prompt = "Please answer the following question:" + + # Initialize the handler + handler = handler_cls(config) + + # Expected conversation format + # expected input + # Call make_conversations + conversations = handler.make_conversations([row], system_prompt) + # Assert the conversation is as expected + assert conversations == [ + expected_conversation + ], f"Expected conversation {expected_conversation} but got {conversations}." 
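For reference, a rough local equivalent of the CI steps these patches wire up (an editable install followed by pytest over tests/), sketched as a short Python snippet rather than taken from any commit in this series; it assumes it is run from the repository root with pip and pytest available:

# Hedged sketch of the CI flow configured above; not part of any commit in this series.
import subprocess
import sys

# Mirror the "Install skythought_evals" step: editable install from the repo root.
subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
# Mirror the "Run tests" step and propagate pytest's exit code.
sys.exit(subprocess.run([sys.executable, "-m", "pytest", "tests/"]).returncode)
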
From 84c96170a511eb0da5b6f5c6466107c719a9b13b Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 03:08:00 +0000 Subject: [PATCH 35/47] set up permissions Signed-off-by: SumanthRH --- .github/workflows/cpu_ci.yaml | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/cpu_ci.yaml b/.github/workflows/cpu_ci.yaml index 38718d6..f8f01cc 100644 --- a/.github/workflows/cpu_ci.yaml +++ b/.github/workflows/cpu_ci.yaml @@ -1,6 +1,16 @@ name: Skythought evals -on: [push] +on: + push: + branches: + - main # or your default branch + pull_request: + branches: + - main + +permissions: + checks: write # for status checks to appear + contents: read # Cancel runs for previous commits on the same branch concurrency: From e032b4792955be19787b2dd04b70c02708e4e160 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 04:31:21 +0000 Subject: [PATCH 36/47] test ci setup Signed-off-by: SumanthRH --- tests/skythought_evals/test_preprocessing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/skythought_evals/test_preprocessing.py b/tests/skythought_evals/test_preprocessing.py index 5662aaa..2d8d1c1 100644 --- a/tests/skythought_evals/test_preprocessing.py +++ b/tests/skythought_evals/test_preprocessing.py @@ -41,8 +41,6 @@ def test_make_conversations(row, config, handler_cls, expected_conversation): # Initialize the handler handler = handler_cls(config) - # Expected conversation format - # expected input # Call make_conversations conversations = handler.make_conversations([row], system_prompt) # Assert the conversation is as expected From 7e99d6057bd1b065bc6b9015c55b9f88e44f4631 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 04:37:57 +0000 Subject: [PATCH 37/47] x Signed-off-by: SumanthRH --- .github/workflows/cpu_ci.yaml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/workflows/cpu_ci.yaml b/.github/workflows/cpu_ci.yaml index f8f01cc..980b6d3 100644 --- a/.github/workflows/cpu_ci.yaml +++ b/.github/workflows/cpu_ci.yaml @@ -1,12 +1,10 @@ name: Skythought evals -on: - push: - branches: - - main # or your default branch - pull_request: - branches: +on: + push: + branches: - main + pull_request: # all branches permissions: checks: write # for status checks to appear From 3f5ff024424b9442768bddae7900501a04d795e1 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 04:39:33 +0000 Subject: [PATCH 38/47] x Signed-off-by: SumanthRH --- .github/workflows/cpu_ci.yaml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/.github/workflows/cpu_ci.yaml b/.github/workflows/cpu_ci.yaml index 980b6d3..300632f 100644 --- a/.github/workflows/cpu_ci.yaml +++ b/.github/workflows/cpu_ci.yaml @@ -1,10 +1,6 @@ name: Skythought evals -on: - push: - branches: - - main - pull_request: # all branches +on: [push, pull_request] permissions: checks: write # for status checks to appear From 79d12a2a65efecd4a2c25b76f4efc96807182d5a Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 06:32:59 +0000 Subject: [PATCH 39/47] update to two workflows Signed-off-by: SumanthRH --- .github/workflows/{cpu_ci.yaml => cpu_ci.yml} | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) rename .github/workflows/{cpu_ci.yaml => cpu_ci.yml} (60%) diff --git a/.github/workflows/cpu_ci.yaml b/.github/workflows/cpu_ci.yml similarity index 60% rename from .github/workflows/cpu_ci.yaml rename to .github/workflows/cpu_ci.yml index 300632f..c9e7df8 100644 --- 
a/.github/workflows/cpu_ci.yaml +++ b/.github/workflows/cpu_ci.yml @@ -1,6 +1,6 @@ name: Skythought evals -on: [push, pull_request] +on: [push, pull_request_target] permissions: checks: write # for status checks to appear @@ -12,8 +12,7 @@ concurrency: cancel-in-progress: true jobs: - build: - + check_code_quality: runs-on: ubuntu-latest steps: @@ -26,10 +25,28 @@ jobs: python-version: '3.10' cache: 'pip' - name: Install dependencies - run: python -m pip install --upgrade pip setuptools wheel pre-commit pytest + run: python -m pip install --upgrade pip setuptools wheel pre-commit - name: Install skythought_evals run: python -m pip install -e . - name: Run pre-commit hooks run: pre-commit run --all-files --config .pre-commit-config.yaml + + tests: + needs: check_code_quality + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + # This is the version of the action for setting up Python, not the Python version. + uses: actions/setup-python@v5 + with: + # Semantic version range syntax or exact version of a Python version + python-version: '3.10' + cache: 'pip' + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel pre-commit pytest + - name: Install skythought_evals + run: python -m pip install -e . - name: Run tests run: python -m pytest tests/ From c8a8d6384567ee3ba776d70f0af2eb81e96649bf Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 19:22:43 +0000 Subject: [PATCH 40/47] update to later vllm; needed for some tokenizer_revision fixes Signed-off-by: SumanthRH --- skythought/skythought_evals/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skythought/skythought_evals/requirements.txt b/skythought/skythought_evals/requirements.txt index 4283fe2..b218ab6 100644 --- a/skythought/skythought_evals/requirements.txt +++ b/skythought/skythought_evals/requirements.txt @@ -1,4 +1,4 @@ -vllm==0.6.2 +vllm==0.7.0 pyext word2number scipy From aa871242230a181bcd8646eb1616858fe64ca5cf Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 22:50:28 +0000 Subject: [PATCH 41/47] x Signed-off-by: SumanthRH --- skythought/skythought_evals/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skythought/skythought_evals/requirements.txt b/skythought/skythought_evals/requirements.txt index b218ab6..5861e92 100644 --- a/skythought/skythought_evals/requirements.txt +++ b/skythought/skythought_evals/requirements.txt @@ -3,4 +3,5 @@ pyext word2number scipy datasets -latex2sympy2 \ No newline at end of file +latex2sympy2 +pydantic \ No newline at end of file From eab138ae1f7490ccb725615a4277327e9fac5850 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Tue, 28 Jan 2025 22:51:24 +0000 Subject: [PATCH 42/47] x Signed-off-by: SumanthRH --- format.sh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/format.sh b/format.sh index afdeb4a..d0fef1a 100644 --- a/format.sh +++ b/format.sh @@ -1,11 +1,6 @@ set -e -# Get tools directory path relative to git root -GIT_ROOT=$(git rev-parse --show-toplevel) -TOOLS_RELATIVE=skythought/tools -TOOLS_DIR=$GIT_ROOT/$TOOLS_RELATIVE - if command -v uv >/dev/null 2>&1; then uv pip install -q pre-commit else From 07c21f90705084327355e4b0ebbac761617d6f46 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 31 Jan 2025 23:57:46 +0000 Subject: [PATCH 43/47] small update Signed-off-by: SumanthRH --- .../skythought_evals/inference_and_check.py | 19 +++++++++++++++---- skythought/skythought_evals/tasks/__init__.py | 2 +- 2 files changed, 
16 insertions(+), 5 deletions(-) diff --git a/skythought/skythought_evals/inference_and_check.py b/skythought/skythought_evals/inference_and_check.py index 7c3f10e..d5893f2 100644 --- a/skythought/skythought_evals/inference_and_check.py +++ b/skythought/skythought_evals/inference_and_check.py @@ -7,7 +7,12 @@ import numpy as np from openai import OpenAI -from skythought_evals.tasks import TASK_HANDLER_MAP, NUMINATaskHandler, TaskHandler +from skythought_evals.tasks import ( + TASK_HANDLER_MAP, + NUMINATaskHandler, + TaskConfig, + TaskHandler, +) from skythought_evals.tasks.task_util import get_tasks from skythought_evals.util.model_utils import MODEL_TO_NAME, SYSTEM_PROMPT from tqdm import tqdm @@ -490,9 +495,15 @@ def main(): ) args = parser.parse_args() - handler_cls: TaskHandler = TASK_HANDLER_MAP[args.task] - config_path = TASK_NAMES_TO_YAML[args.task] - handler = handler_cls.from_config_path(config_path) + if args.task not in TASK_NAMES_TO_YAML: + raise ValueError( + f"Task {args.task} not found. Should be one of {TASK_NAMES_TO_YAML.keys()}" + ) + + task_config = TaskConfig.from_yaml(TASK_NAMES_TO_YAML[args.task]) + handler_name = task_config.handler + handler_cls = TASK_HANDLER_MAP[handler_name] + handler = handler_cls(task_config) temperatures = [1] if args.model.startswith("openai/o1") else args.temperatures diff --git a/skythought/skythought_evals/tasks/__init__.py b/skythought/skythought_evals/tasks/__init__.py index a9d2882..d03b60d 100644 --- a/skythought/skythought_evals/tasks/__init__.py +++ b/skythought/skythought_evals/tasks/__init__.py @@ -27,7 +27,7 @@ "arc_c": ARCChallengeTaskHandler, "amc23": AMC23TaskHandler, "minervamath": MinervaMathTaskHandler, - "olympiadbench_math_en": OlympiadBenchMathTaskHandler, + "olympiadbench_math": OlympiadBenchMathTaskHandler, } __all__ = [ From 3d6942f19cf9a0cf36a7229a1e6c3352d02b1f93 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Sat, 1 Feb 2025 01:06:35 +0000 Subject: [PATCH 44/47] reworking args Signed-off-by: SumanthRH --- skythought/skythought_evals/README.md | 9 ++-- skythought/skythought_evals/eval.py | 6 +-- .../skythought_evals/inference_and_check.py | 47 ++++++++++++------- .../tasks/aime/aime_handler.py | 4 +- .../tasks/amc23/amc23_handler.py | 4 +- .../skythought_evals/tasks/apps/apps.yaml | 2 +- .../tasks/apps/apps_handler.py | 12 ++--- .../skythought_evals/tasks/arc/arc_c.yaml | 2 +- .../skythought_evals/tasks/arc/arc_handler.py | 4 +- skythought/skythought_evals/tasks/base.py | 14 ++---- .../tasks/gpqa_diamond/gpqa_diamond.yaml | 2 +- .../gpqa_diamond/gpqa_diamond_handler.py | 4 +- .../skythought_evals/tasks/gsm8k/gsm8k.yaml | 2 +- .../tasks/gsm8k/gsm8k_handler.py | 4 +- .../tasks/livecodebench/livecodebench.yaml | 2 +- .../livecodebench/livecodebench_handler.py | 12 ++--- .../skythought_evals/tasks/math/math500.yaml | 2 +- .../tasks/math/math_handler.py | 4 +- .../tasks/minervamath/minervamath.yaml | 2 +- .../skythought_evals/tasks/mmlu/mmlu.yaml | 2 +- .../tasks/mmlu/mmlu_handler.py | 8 ++-- .../skythought_evals/tasks/mmlu/mmlu_pro.yaml | 2 +- .../skythought_evals/tasks/numina/numina.yaml | 9 ++-- .../tasks/numina/numina_handler.py | 31 ++++++++---- .../olympiadbench/olympiadbench_math_en.yaml | 2 +- .../skythought_evals/tasks/taco/taco.yaml | 2 +- .../tasks/taco/taco_handler.py | 12 ++--- skythought/skythought_evals/util/common.py | 13 +++++ 28 files changed, 130 insertions(+), 89 deletions(-) diff --git a/skythought/skythought_evals/README.md b/skythought/skythought_evals/README.md index 26d2f15..a18e759 100644 --- 
a/skythought/skythought_evals/README.md +++ b/skythought/skythought_evals/README.md @@ -69,18 +69,17 @@ Currently we support distill and reject sampling from various self-hosted models ```shell python inference_and_check.py --task apps --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data -python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source MEDIUM --filter-difficulty --result-dir $SKYT_HOME/data +python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --filter-difficulty MEDIUM --result-dir $SKYT_HOME/data python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data -python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 4 --math_difficulty_upper_bound 9 +python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty true --result-dir $SKYT_HOME/data --math-difficulty-lower-bound 4 --math-difficulty-upper-bound 9 -python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 1 --math_difficulty_upper_bound 9 +python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty true --result-dir $SKYT_HOME/data --math-difficulty-lower-bound 1 --math-difficulty-upper-bound 9 -python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source olympiads --end 20000 --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 9 --math_difficulty_upper_bound 9 +python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --end 20000 --filter-difficulty olympiads --result-dir $SKYT_HOME/data --math-difficulty-lower-bound 9 --math-difficulty-upper-bound 9 ``` - #### Best-of-N Inference and Check ```bash python inference_and_check.py --task math500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --inference --temperatures 0.7 --n 64 diff --git a/skythought/skythought_evals/eval.py b/skythought/skythought_evals/eval.py index b751aa9..411a697 100644 --- a/skythought/skythought_evals/eval.py +++ b/skythought/skythought_evals/eval.py @@ -102,9 +102,9 @@ def main(): ] command.extend(temperatures) # Add temperatures as separate arguments - if args.filter_difficulty: - command.append("--filter-difficulty") - command.append(args.filter_difficulty) + if args.difficulty: + command.append("--difficulty") + command.append(args.difficulty) print(f"Running eval {eval_name} with command {command}") all_logs += f"\nRunning eval: {eval_name} with command {command}\n" diff --git a/skythought/skythought_evals/inference_and_check.py b/skythought/skythought_evals/inference_and_check.py index d5893f2..4636a25 100644 --- a/skythought/skythought_evals/inference_and_check.py +++ b/skythought/skythought_evals/inference_and_check.py @@ -14,6 +14,7 @@ TaskHandler, ) from skythought_evals.tasks.task_util import get_tasks +from 
skythought_evals.util.common import set_seed from skythought_evals.util.model_utils import MODEL_TO_NAME, SYSTEM_PROMPT from tqdm import tqdm from vllm import LLM, SamplingParams @@ -70,8 +71,8 @@ def perform_inference_and_check( args.start, args.end, split=args.split, - source=args.source, - filter_difficulty=args.filter_difficulty, + subset=args.subset, + difficulty=args.difficulty, args=args, ) remaining_data = handler.process_remaining_data(train_data, results) @@ -212,8 +213,8 @@ def perform_check(handler: TaskHandler, temperatures, result_file, args): args.start, args.end, split=args.split, - source=args.source, - filter_difficulty=args.filter_difficulty, + subset=args.subset, + difficulty=args.difficulty, args=args, ) remaining_data = handler.process_remaining_data(train_data, {}) @@ -313,8 +314,8 @@ def perform_inference_and_save( args.start, args.end, split=args.split, - source=args.source, - filter_difficulty=args.filter_difficulty, + subset=args.subset, + difficulty=args.difficulty, args=args, ) remaining_data = handler.process_remaining_data(train_data, results) @@ -453,14 +454,24 @@ def main(): default=None, help="Split to use for the dataset (e.g., train, test).", ) - parser.add_argument("--source", type=str, help="Source for the dataset.") + parser.add_argument("--subset", type=str, help="Subset for the dataset.") parser.add_argument("--start", type=int, default=0, help="Start index.") parser.add_argument("--end", type=int, default=-1, help="End index.") parser.add_argument( - "--filter-difficulty", + "--difficulty", type=str, default=None, - help="Optional filter difficulty. Options: 'easy', 'medium', 'hard'.", + help="Difficulty level. Example: 'easy', 'medium', 'hard'.", + ) + parser.add_argument( + "--filter-difficulty", + action="store_true", + help="Optional filter difficulty, used for NUMINA.", + ) + parser.add_argument( + "--source", + type=str, + help="Source column filter for the dataset, used for NUMINA.", ) parser.add_argument( "--result-dir", type=str, default="./", help="Result dir to save files." @@ -493,7 +504,10 @@ def main(): parser.add_argument( "--n", type=int, default=1, help="Number of samples generated per problem." 
) + parser.add_argument("--seed", type=int, default=41, help="Random seed.") + args = parser.parse_args() + set_seed(args.seed) if args.task not in TASK_NAMES_TO_YAML: raise ValueError( @@ -516,9 +530,10 @@ def main(): # TODO: this can be cleaned up by allowing user override for any task_config with optional task_args # Currently kept here for consistency with old code args.split = args.split if args.split else handler.task_config.dataset_split - args.source = args.source if args.source else handler.task_config.dataset_source - if not args.filter_difficulty and handler.task_config.preprocess_config: - args.filter_difficulty = handler.task_config.preprocess_config.difficulty + args.subset = args.subset if args.subset else handler.task_config.dataset_subset + if not args.difficulty and "difficulty" in handler.task_config.preprocess_config: + args.difficulty = handler.task_config.preprocess_config["difficulty"] + # create result dir if not exists if args.result_dir and not os.path.exists(args.result_dir): os.makedirs(args.result_dir) @@ -528,12 +543,12 @@ def main(): ): result_file = os.path.join( args.result_dir, - f"{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.filter_difficulty}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json", + f"{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.subset}_{args.filter_difficulty}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json", ) else: result_file = os.path.join( args.result_dir, - f"{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.filter_difficulty}_{args.start}_{args.end}.json", + f"{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.subset}_{args.filter_difficulty}_{args.start}_{args.end}.json", ) if args.check: @@ -543,11 +558,11 @@ def main(): or args.math_difficulty_upper_bound is not None ): converted_file = ( - f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.filter_difficulty}_{args.start}_{args.end}" + f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.subset}_{args.filter_difficulty}_{args.start}_{args.end}" + f"_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json" ) else: - converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.source}_{args.filter_difficulty}" + converted_file = f"{args.result_dir}/converted_{MODEL_TO_NAME[args.model]}_{args.task}_{args.split}_{args.subset}_{args.filter_difficulty}" f"_{args.start}_{args.end}.json" if os.path.exists(converted_file): result_file = converted_file diff --git a/skythought/skythought_evals/tasks/aime/aime_handler.py b/skythought/skythought_evals/tasks/aime/aime_handler.py index 0fd8b43..9e0756c 100644 --- a/skythought/skythought_evals/tasks/aime/aime_handler.py +++ b/skythought/skythought_evals/tasks/aime/aime_handler.py @@ -29,8 +29,8 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): - train_data = self.load_dataset(source=source, split=split).to_pandas() + train_data = self.load_dataset(subset=subset, split=split).to_pandas() filtered_data = train_data[train_data["url"].str.contains("2024", na=False)] return filtered_data.iloc[start:end] if end > 0 
else filtered_data.iloc[start:] diff --git a/skythought/skythought_evals/tasks/amc23/amc23_handler.py b/skythought/skythought_evals/tasks/amc23/amc23_handler.py index 5213559..46d29bc 100644 --- a/skythought/skythought_evals/tasks/amc23/amc23_handler.py +++ b/skythought/skythought_evals/tasks/amc23/amc23_handler.py @@ -3,8 +3,8 @@ class AMC23TaskHandler(MathTaskHandler): def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): - train_data = self.load_dataset(source=source, split=split).to_pandas() + train_data = self.load_dataset(subset=subset, split=split).to_pandas() filtered_data = train_data[train_data["url"].str.contains("2023", na=False)] return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:] diff --git a/skythought/skythought_evals/tasks/apps/apps.yaml b/skythought/skythought_evals/tasks/apps/apps.yaml index eeb1569..0ff7e01 100644 --- a/skythought/skythought_evals/tasks/apps/apps.yaml +++ b/skythought/skythought_evals/tasks/apps/apps.yaml @@ -1,6 +1,6 @@ handler: apps dataset_path: codeparrot/apps -dataset_source: all +dataset_subset: all dataset_kwargs: trust_remote_code: true dataset_split: test diff --git a/skythought/skythought_evals/tasks/apps/apps_handler.py b/skythought/skythought_evals/tasks/apps/apps_handler.py index 63c6b39..563fcde 100644 --- a/skythought/skythought_evals/tasks/apps/apps_handler.py +++ b/skythought/skythought_evals/tasks/apps/apps_handler.py @@ -99,14 +99,14 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): - train_data = self.load_dataset(source=source, split=split).to_pandas() - if filter_difficulty or self.task_config.preprocess_config.difficulty: + train_data = self.load_dataset(subset=subset, split=split).to_pandas() + if difficulty or "difficulty" in self.task_config.preprocess_config: difficulty = ( - self.task_config.preprocess_config.difficulty - if not filter_difficulty - else filter_difficulty + self.task_config.preprocess_config["difficulty"] + if not difficulty + else difficulty ) train_data = train_data.filter(lambda x: x["difficulty"] == difficulty) diff --git a/skythought/skythought_evals/tasks/arc/arc_c.yaml b/skythought/skythought_evals/tasks/arc/arc_c.yaml index 09f83e9..bad44fd 100644 --- a/skythought/skythought_evals/tasks/arc/arc_c.yaml +++ b/skythought/skythought_evals/tasks/arc/arc_c.yaml @@ -1,6 +1,6 @@ handler: arc_c dataset_path: allenai/ai2_arc -dataset_source: ARC-Challenge +dataset_subset: ARC-Challenge dataset_split: train question_key: question answer_key: answerKey diff --git a/skythought/skythought_evals/tasks/arc/arc_handler.py b/skythought/skythought_evals/tasks/arc/arc_handler.py index 221c94d..9c1e48d 100644 --- a/skythought/skythought_evals/tasks/arc/arc_handler.py +++ b/skythought/skythought_evals/tasks/arc/arc_handler.py @@ -69,9 +69,9 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): - train_data = self.load_dataset(source=source, split=split).to_pandas() + train_data = self.load_dataset(subset=subset, split=split).to_pandas() 
return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): diff --git a/skythought/skythought_evals/tasks/base.py b/skythought/skythought_evals/tasks/base.py index b91e8e4..7ebfdb9 100644 --- a/skythought/skythought_evals/tasks/base.py +++ b/skythought/skythought_evals/tasks/base.py @@ -8,14 +8,10 @@ from pydantic import BaseModel, Field -class PreprocessConfig(BaseModel): - difficulty: str - - class TaskConfig(BaseModel): handler: str dataset_path: str - dataset_source: Optional[str] = None + dataset_subset: Optional[str] = None dataset_split: str dataset_kwargs: Dict[str, Any] = Field(default_factory=dict) question_key: str @@ -26,7 +22,7 @@ class TaskConfig(BaseModel): fewshot_config: List[Dict[str, Any]] = Field(default_factory=list) num_fewshot: int = 0 - preprocess_config: Optional[PreprocessConfig] = None + preprocess_config: Dict[str, Any] = Field(default_factory=dict) @classmethod def from_yaml(cls, yaml_file_path) -> "TaskConfig": @@ -65,17 +61,17 @@ def load_existing_results(self, result_file): records = json.load(f) return records - def load_dataset(self, source=None, split=None, **kwargs) -> HFDataset: + def load_dataset(self, subset=None, split=None, **kwargs) -> HFDataset: dataset = load_dataset( path=self.task_config.dataset_path, - name=source if source else self.task_config.dataset_source, + name=subset if subset else self.task_config.dataset_subset, split=split if split else self.task_config.dataset_split, **self.task_config.dataset_kwargs ) return dataset def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): raise NotImplementedError("Subclasses should implement this method.") diff --git a/skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond.yaml b/skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond.yaml index 940d960..963bf6f 100644 --- a/skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond.yaml +++ b/skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond.yaml @@ -1,6 +1,6 @@ handler: gpqa_diamond dataset_path: Idavidrein/gpqa -dataset_source: gpqa_diamond +dataset_subset: gpqa_diamond dataset_split: train question_key: Question answer_key: Answer diff --git a/skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond_handler.py b/skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond_handler.py index fb0d5ef..8bd5c28 100644 --- a/skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond_handler.py +++ b/skythought/skythought_evals/tasks/gpqa_diamond/gpqa_diamond_handler.py @@ -82,9 +82,9 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): - train_data = self.load_dataset(source=source, split=split).to_pandas() + train_data = self.load_dataset(subset=subset, split=split).to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): diff --git a/skythought/skythought_evals/tasks/gsm8k/gsm8k.yaml b/skythought/skythought_evals/tasks/gsm8k/gsm8k.yaml index 2ef5012..58159f9 100644 --- a/skythought/skythought_evals/tasks/gsm8k/gsm8k.yaml +++ b/skythought/skythought_evals/tasks/gsm8k/gsm8k.yaml @@ -1,6 +1,6 @@ handler: gsm8k dataset_path: 
"openai/gsm8k" -dataset_source: main +dataset_subset: main dataset_split: test question_key: question answer_key: answer diff --git a/skythought/skythought_evals/tasks/gsm8k/gsm8k_handler.py b/skythought/skythought_evals/tasks/gsm8k/gsm8k_handler.py index f913b51..5bde1ee 100644 --- a/skythought/skythought_evals/tasks/gsm8k/gsm8k_handler.py +++ b/skythought/skythought_evals/tasks/gsm8k/gsm8k_handler.py @@ -54,9 +54,9 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): - train_data = self.load_dataset(source=source, split=split).to_pandas() + train_data = self.load_dataset(subset=subset, split=split).to_pandas() return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:] def process_remaining_data(self, train_data, results): diff --git a/skythought/skythought_evals/tasks/livecodebench/livecodebench.yaml b/skythought/skythought_evals/tasks/livecodebench/livecodebench.yaml index ec060be..28fd7f2 100644 --- a/skythought/skythought_evals/tasks/livecodebench/livecodebench.yaml +++ b/skythought/skythought_evals/tasks/livecodebench/livecodebench.yaml @@ -1,6 +1,6 @@ handler: livecodebench dataset_path: "livecodebench/code_generation_lite" # repo ID in huggingface -dataset_source: null +dataset_subset: null dataset_split: test dataset_kwargs: version_tag: release_v2 diff --git a/skythought/skythought_evals/tasks/livecodebench/livecodebench_handler.py b/skythought/skythought_evals/tasks/livecodebench/livecodebench_handler.py index 8f99ce2..0f441de 100644 --- a/skythought/skythought_evals/tasks/livecodebench/livecodebench_handler.py +++ b/skythought/skythought_evals/tasks/livecodebench/livecodebench_handler.py @@ -98,15 +98,15 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): - dataset: HFDataset = self.load_dataset(source=source, split=split) + dataset: HFDataset = self.load_dataset(subset=subset, split=split) # Filter by CLI or config - if filter_difficulty or self.task_config.preprocess_config.difficulty: + if difficulty or "difficulty" in self.task_config.preprocess_config: difficulty = ( - filter_difficulty - if filter_difficulty - else self.task_config.preprocess_config.difficulty + difficulty + if difficulty + else self.task_config.preprocess_config["difficulty"] ) dataset = dataset.filter( lambda example: example["difficulty"] == difficulty diff --git a/skythought/skythought_evals/tasks/math/math500.yaml b/skythought/skythought_evals/tasks/math/math500.yaml index 43c0e82..135cdb9 100644 --- a/skythought/skythought_evals/tasks/math/math500.yaml +++ b/skythought/skythought_evals/tasks/math/math500.yaml @@ -1,6 +1,6 @@ handler: math dataset_path: "qq8933/MATH500" # repo ID in huggingface -dataset_source: null # which subset on huggingface +dataset_subset: null # which subset on huggingface question_key: problem answer_key: answer dataset_split: test diff --git a/skythought/skythought_evals/tasks/math/math_handler.py b/skythought/skythought_evals/tasks/math/math_handler.py index ca948b2..b07c829 100644 --- a/skythought/skythought_evals/tasks/math/math_handler.py +++ b/skythought/skythought_evals/tasks/math/math_handler.py @@ -56,7 +56,7 @@ def 
process_remaining_data(self, train_data, results): ] def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): - dataset = self.load_dataset(source=source, split=split).to_pandas() + dataset = self.load_dataset(subset=subset, split=split).to_pandas() return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] diff --git a/skythought/skythought_evals/tasks/minervamath/minervamath.yaml b/skythought/skythought_evals/tasks/minervamath/minervamath.yaml index 85ba7aa..d7f707b 100644 --- a/skythought/skythought_evals/tasks/minervamath/minervamath.yaml +++ b/skythought/skythought_evals/tasks/minervamath/minervamath.yaml @@ -1,6 +1,6 @@ handler: math dataset_path: "svc-huggingface/minerva-math" # repo ID in huggingface -dataset_source: null # which subset on huggingface +dataset_subset: null # which subset on huggingface question_key: problem answer_key: solution dataset_split: test diff --git a/skythought/skythought_evals/tasks/mmlu/mmlu.yaml b/skythought/skythought_evals/tasks/mmlu/mmlu.yaml index ad98fd5..cf0b937 100644 --- a/skythought/skythought_evals/tasks/mmlu/mmlu.yaml +++ b/skythought/skythought_evals/tasks/mmlu/mmlu.yaml @@ -1,6 +1,6 @@ handler: mmlu dataset_path: cais/mmlu -dataset_source: all +dataset_subset: all dataset_split: test question_key: question answer_key: answer diff --git a/skythought/skythought_evals/tasks/mmlu/mmlu_handler.py b/skythought/skythought_evals/tasks/mmlu/mmlu_handler.py index b69174c..3ca7bc3 100644 --- a/skythought/skythought_evals/tasks/mmlu/mmlu_handler.py +++ b/skythought/skythought_evals/tasks/mmlu/mmlu_handler.py @@ -65,9 +65,9 @@ def process_remaining_data(self, train_data, results): ] def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): - dataset = self.load_dataset(source=source, split=split).to_pandas() + dataset = self.load_dataset(subset=subset, split=split).to_pandas() return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] @@ -109,7 +109,7 @@ def get_multiple_choice_answers(self, problem): return f"Answer Choices: {options}" def load_and_filter_dataset( - self, start, end, split=None, source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): - dataset = self.load_dataset(source=source, split=split).to_pandas() + dataset = self.load_dataset(subset=subset, split=split).to_pandas() return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] diff --git a/skythought/skythought_evals/tasks/mmlu/mmlu_pro.yaml b/skythought/skythought_evals/tasks/mmlu/mmlu_pro.yaml index 4b88e92..0c48537 100644 --- a/skythought/skythought_evals/tasks/mmlu/mmlu_pro.yaml +++ b/skythought/skythought_evals/tasks/mmlu/mmlu_pro.yaml @@ -1,6 +1,6 @@ handler: mmlu_pro dataset_path: TIGER-Lab/MMLU-Pro -dataset_source: default +dataset_subset: default dataset_split: test question_key: question answer_key: answer diff --git a/skythought/skythought_evals/tasks/numina/numina.yaml b/skythought/skythought_evals/tasks/numina/numina.yaml index 7f4c8f8..9130c78 100644 --- a/skythought/skythought_evals/tasks/numina/numina.yaml +++ b/skythought/skythought_evals/tasks/numina/numina.yaml @@ -1,11 +1,14 @@ handler: numina dataset_path: "AI-MO/NuminaMath-CoT" -dataset_source: default +dataset_subset: null dataset_split: train question_key: problem 
answer_key: solution templating_parameters: template: "Return your final response within \\boxed{{}}. {prompt}" # Optionally, you can filter the dataset by difficulty -# preprocess_config: -# difficulty: easy +preprocess_config: + filter_difficulty: true + math_difficulty_lower_bound: 4 + math_difficulty_upper_bound: 9 + source: math diff --git a/skythought/skythought_evals/tasks/numina/numina_handler.py b/skythought/skythought_evals/tasks/numina/numina_handler.py index f2842e0..658021a 100644 --- a/skythought/skythought_evals/tasks/numina/numina_handler.py +++ b/skythought/skythought_evals/tasks/numina/numina_handler.py @@ -72,21 +72,36 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, start, end, split="train", source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): - dataset = self.load_dataset(source=source, split=split).to_pandas() + dataset = self.load_dataset(subset=subset, split=split).to_pandas() + + if args.source: + dataset = dataset[dataset["source"] == args.source] dataset = dataset.iloc[start:end] if end > 0 else dataset.iloc[start:] dataset = dataset[dataset["solution"].str.contains("boxed", na=False)] - if filter_difficulty: - diff_dict = self.get_difficulty_dict(source, start, end) + + if ( + args.filter_difficulty + or "filter_difficulty" in self.task_config.preprocess_config + ): + lower_bound = ( + args.math_difficulty_lower_bound + if args.filter_difficulty + else self.task_config.preprocess_config["math_difficulty_lower_bound"] + ) + upper_bound = ( + args.math_difficulty_upper_bound + if args.filter_difficulty + else self.task_config.preprocess_config["math_difficulty_upper_bound"] + ) + diff_dict = self.get_difficulty_dict(args.source, start, end) dataset = dataset[ dataset["problem"] .map(diff_dict) - .apply( - lambda x: x >= args.math_difficulty_lower_bound - and x <= args.math_difficulty_upper_bound - ) + .apply(lambda x: x >= lower_bound and x <= upper_bound) ] + return dataset def process_remaining_data(self, train_data, results): diff --git a/skythought/skythought_evals/tasks/olympiadbench/olympiadbench_math_en.yaml b/skythought/skythought_evals/tasks/olympiadbench/olympiadbench_math_en.yaml index b532ed5..3311a0e 100644 --- a/skythought/skythought_evals/tasks/olympiadbench/olympiadbench_math_en.yaml +++ b/skythought/skythought_evals/tasks/olympiadbench/olympiadbench_math_en.yaml @@ -1,6 +1,6 @@ handler: olympiadbench_math dataset_path: Hothan/OlympiadBench -dataset_source: OE_TO_maths_en_COMP +dataset_subset: OE_TO_maths_en_COMP dataset_split: train question_key: question answer_key: final_answer diff --git a/skythought/skythought_evals/tasks/taco/taco.yaml b/skythought/skythought_evals/tasks/taco/taco.yaml index 961d0f4..f7060e4 100644 --- a/skythought/skythought_evals/tasks/taco/taco.yaml +++ b/skythought/skythought_evals/tasks/taco/taco.yaml @@ -1,6 +1,6 @@ handler: taco dataset_path: "BAAI/TACO" -dataset_source: MEDIUM +dataset_subset: MEDIUM dataset_split: train dataset_kwargs: trust_remote_code: true diff --git a/skythought/skythought_evals/tasks/taco/taco_handler.py b/skythought/skythought_evals/tasks/taco/taco_handler.py index 48b36a0..cd53a63 100644 --- a/skythought/skythought_evals/tasks/taco/taco_handler.py +++ b/skythought/skythought_evals/tasks/taco/taco_handler.py @@ -107,14 +107,14 @@ def make_conversations(self, data, system_prompt, model=None): return conversations def load_and_filter_dataset( - self, 
start, end, split=None, source=None, filter_difficulty=None, args=None + self, start, end, split=None, subset=None, difficulty=None, args=None ): - dataset = self.load_dataset(source=source, split=split).to_pandas() - if filter_difficulty or self.task_config.preprocess_config.difficulty: + dataset = self.load_dataset(subset=subset, split=split).to_pandas() + if difficulty or "difficulty" in self.task_config.preprocess_config: difficulty = ( - source - if filter_difficulty - else self.task_config.preprocess_config.difficulty + difficulty + if difficulty + else self.task_config.preprocess_config["difficulty"] ) dataset = dataset.filter( lambda example: example["difficulty"] == difficulty diff --git a/skythought/skythought_evals/util/common.py b/skythought/skythought_evals/util/common.py index e24bc23..8957615 100644 --- a/skythought/skythought_evals/util/common.py +++ b/skythought/skythought_evals/util/common.py @@ -1,6 +1,19 @@ import multiprocessing +import os +import random import re +import numpy as np +import torch + + +def set_seed(seed: int): + os.environ["PYTHONHASHSEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + class TimeoutException(Exception): """Custom exception for function timeout.""" From c2944fe7e50125b62cfe4a5161d7b1dd124ea355 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Sat, 1 Feb 2025 01:12:06 +0000 Subject: [PATCH 45/47] x Signed-off-by: SumanthRH --- skythought/skythought_evals/README.md | 22 +++++++++---------- skythought/skythought_evals/eval.py | 2 +- .../skythought_evals/tasks/numina/numina.yaml | 10 ++++----- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/skythought/skythought_evals/README.md b/skythought/skythought_evals/README.md index a18e759..9ae0113 100644 --- a/skythought/skythought_evals/README.md +++ b/skythought/skythought_evals/README.md @@ -27,13 +27,13 @@ The expected output is labeled_source_0_-1.json. We also provide instructions to Inference the results from QwQ on several datasets. In preview version, we use data from the following dataset. 
 ```shell
-python inference_and_check.py --task apps --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --inference
+python inference_and_check.py --task apps --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --difficulty all --result-dir $SKYT_HOME/data --inference
-python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source MEDIUM --filter-difficulty --result-dir $SKYT_HOME/data --inference
+python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --difficulty MEDIUM --result-dir $SKYT_HOME/data --inference
-python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --inference
+python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --difficulty all --result-dir $SKYT_HOME/data --inference
-python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty math --result-dir $SKYT_HOME/data --inference
+python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty --result-dir $SKYT_HOME/data --inference
 python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty --result-dir $SKYT_HOME/data --inference
@@ -48,7 +48,7 @@ python convert_format.py --input_dir $SKYT_HOME/data --keys keys.txt
 ### Step 3: Reject Sampling on the formatted data (Example Usage with previous script)
 ```shell
-python inference_and_check.py --task apps --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --check
+python inference_and_check.py --task apps --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --subset all --result-dir $SKYT_HOME/data --check
 ```
 Similar for other datasets. 
@@ -67,17 +67,17 @@ Currently we support distill and reject sampling from various self-hosted models #### Example Usage ```shell -python inference_and_check.py --task apps --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data +python inference_and_check.py --task apps --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --difficulty all --result-dir $SKYT_HOME/data -python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --filter-difficulty MEDIUM --result-dir $SKYT_HOME/data +python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --difficulty MEDIUM --result-dir $SKYT_HOME/data -python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data +python inference_and_check.py --task taco --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --difficulty all --result-dir $SKYT_HOME/data -python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty true --result-dir $SKYT_HOME/data --math-difficulty-lower-bound 4 --math-difficulty-upper-bound 9 +python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty --result-dir $SKYT_HOME/data --math-difficulty-lower-bound 4 --math-difficulty-upper-bound 9 -python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty true --result-dir $SKYT_HOME/data --math-difficulty-lower-bound 1 --math-difficulty-upper-bound 9 +python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty --result-dir $SKYT_HOME/data --math-difficulty-lower-bound 1 --math-difficulty-upper-bound 9 -python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --end 20000 --filter-difficulty olympiads --result-dir $SKYT_HOME/data --math-difficulty-lower-bound 9 --math-difficulty-upper-bound 9 +python inference_and_check.py --task numina --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --end 20000--source olympiads --filter-difficulty --result-dir $SKYT_HOME/data --math-difficulty-lower-bound 9 --math-difficulty-upper-bound 9 ``` #### Best-of-N Inference and Check diff --git a/skythought/skythought_evals/eval.py b/skythought/skythought_evals/eval.py index 411a697..f4e65d8 100644 --- a/skythought/skythought_evals/eval.py +++ b/skythought/skythought_evals/eval.py @@ -27,7 +27,7 @@ def parse_arguments(): default=None, help="Optional filter difficulty. Options: 'easy', 'medium', 'hard'.", ) - parser.add_argument("--source", type=str, help="Source for the dataset.") + parser.add_argument("--subset", type=str, help="Subset for the dataset.") parser.add_argument( "--output_file", required=True, diff --git a/skythought/skythought_evals/tasks/numina/numina.yaml b/skythought/skythought_evals/tasks/numina/numina.yaml index 9130c78..7431a7f 100644 --- a/skythought/skythought_evals/tasks/numina/numina.yaml +++ b/skythought/skythought_evals/tasks/numina/numina.yaml @@ -7,8 +7,8 @@ answer_key: solution templating_parameters: template: "Return your final response within \\boxed{{}}. 
{prompt}" # Optionally, you can filter the dataset by difficulty -preprocess_config: - filter_difficulty: true - math_difficulty_lower_bound: 4 - math_difficulty_upper_bound: 9 - source: math +# preprocess_config: +# filter_difficulty: true +# math_difficulty_lower_bound: 4 +# math_difficulty_upper_bound: 9 +# source: math From 8a397011885ce5dd947ce77227df001ca2e83945 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Sat, 1 Feb 2025 21:28:51 +0000 Subject: [PATCH 46/47] x Signed-off-by: SumanthRH --- skythought/skythought_evals/eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skythought/skythought_evals/eval.py b/skythought/skythought_evals/eval.py index f4e65d8..3180a95 100644 --- a/skythought/skythought_evals/eval.py +++ b/skythought/skythought_evals/eval.py @@ -22,10 +22,10 @@ def parse_arguments(): ) parser.add_argument("--tp", type=int, default=8, help="Tensor Parallelism Degree") parser.add_argument( - "--filter-difficulty", + "--difficulty", type=str, default=None, - help="Optional filter difficulty. Options: 'easy', 'medium', 'hard'.", + help="Difficulty for the dataset. Options: 'easy', 'medium', 'hard'", ) parser.add_argument("--subset", type=str, help="Subset for the dataset.") parser.add_argument( From 503ea622188712fd6ac518c97f33132356f74732 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Sat, 1 Feb 2025 23:05:47 +0000 Subject: [PATCH 47/47] x Signed-off-by: SumanthRH --- setup.py | 13 +++++-------- tests/{skythought_evals => evals}/__init__.py | 0 .../test_preprocessing.py | 0 3 files changed, 5 insertions(+), 8 deletions(-) rename tests/{skythought_evals => evals}/__init__.py (100%) rename tests/{skythought_evals => evals}/test_preprocessing.py (100%) diff --git a/setup.py b/setup.py index cdfebb7..e85572b 100644 --- a/setup.py +++ b/setup.py @@ -13,14 +13,11 @@ def get_requirements(): setuptools.setup( name="skythought_evals", version="0.0.1", - package_dir={ - "skythought_evals": "skythought/skythought_evals" - }, # map skythought_evals to skythought/skythought_evals - packages=["skythought_evals"] - + [ - f"skythought_evals.{pkg}" - for pkg in setuptools.find_packages(where="skythought/skythought_evals") - ], + package_dir={"": "skythought"}, + packages=setuptools.find_packages( + where="skythought", + include=["skythought_evals*"], # Only pick up skythought_evals, skip 'train' + ), install_requires=get_requirements(), python_requires=">=3.9,<3.12", # pyext doesn't work with python 3.12 ) diff --git a/tests/skythought_evals/__init__.py b/tests/evals/__init__.py similarity index 100% rename from tests/skythought_evals/__init__.py rename to tests/evals/__init__.py diff --git a/tests/skythought_evals/test_preprocessing.py b/tests/evals/test_preprocessing.py similarity index 100% rename from tests/skythought_evals/test_preprocessing.py rename to tests/evals/test_preprocessing.py