diff --git a/examples/aflow/README.md b/examples/aflow/README.md index c2abdff48a..9567e8cf4b 100644 --- a/examples/aflow/README.md +++ b/examples/aflow/README.md @@ -5,7 +5,7 @@ AFlow is a framework for automatically generating and optimizing Agentic Workflo [Read our paper on arXiv](https://arxiv.org/abs/2410.10762)

-Performance Of AFLOW +Performance Of AFLOW

## Framework Components @@ -17,7 +17,7 @@ AFlow is a framework for automatically generating and optimizing Agentic Workflo - **Evaluator**: Assesses workflow performance on given tasks. Provides feedback to guide the optimization process towards more effective workflows. See `metagpt/ext/aflow/scripts/evaluator.py` for details.

-Performance Of AFLOW +Performance Of AFLOW

## Datasets @@ -26,7 +26,7 @@ AFlow is a framework for automatically generating and optimizing Agentic Workflo We conducted experiments on six datasets (HumanEval, MBPP, GSM8K, MATH, HotpotQA, DROP) and provide their evaluation code. The data can be found in this [datasets](https://drive.google.com/uc?export=download&id=1DNoegtZiUhWtvkd2xoIuElmIi4ah7k8e) link, or you can download them using `metagpt/ext/aflow/data/download_data.py`

-Performance Of AFLOW +Performance Of AFLOW

### Custom Datasets @@ -34,31 +34,41 @@ For custom tasks, you can reference the code in the `metagpt/ext/aflow/benchmark ## Quick Start -1. Configure your search in `optimize.py`: - - Open `examples/aflow/optimize.py` - - Set the following parameters: +1. Configure optimization parameters: + - Use command line arguments or modify default parameters in `examples/aflow/optimize.py`: ```python - dataset: DatasetType = "MATH" # Ensure the type is consistent with DatasetType - sample: int = 4 # Sample Count, which means how many workflows will be resampled from generated workflows - question_type: QuestionType = "math" # Ensure the type is consistent with QuestionType - optimized_path: str = "metagpt/ext/aflow/scripts/optimized" # Optimized Result Save Path - initial_round: int = 1 # Corrected the case from Initial_round to initial_round - max_rounds: int = 20 # The max iteration of AFLOW. - check_convergence: bool = True # Whether Early Stop - validation_rounds: int = 5 # The validation rounds of AFLOW. - if_fisrt_optimize = True # You should change it to False after the first optimize. + --dataset MATH # Dataset type (HumanEval/MBPP/GSM8K/MATH/HotpotQA/DROP) + --sample 4 # Sample count - number of workflows to be resampled + --question_type math # Question type (math/code/qa) + --optimized_path PATH # Optimized result save path + --initial_round 1 # Initial round + --max_rounds 20 # Max iteration rounds for AFLOW + --check_convergence # Whether to enable early stop + --validation_rounds 5 # Validation rounds for AFLOW + --if_first_optimize # Set True for first optimization, False afterwards ``` - - Adjust these parameters according to your specific requirements and dataset -2. Set up parameters in `config/config2.yaml` (see `examples/aflow/config2.example.yaml` for reference) -3. Set the operator you want to use in `optimize.py` and in `optimized_path/template/operator.py`, `optimized_path/template/operator.json`. You can reference our implementation to add operators for specific datasets -4. When you first run, you can download the datasets and initial rounds by setting `download(["datasets", "initial_rounds"])` in `examples/aflow/optimize.py` + +2. Configure LLM parameters in `config/config2.yaml` (see `examples/aflow/config2.example.yaml` for reference) + +3. Set up operators in `optimize.py` and in `optimized_path/template/operator.py`, `optimized_path/template/operator.json`. You can reference our implementation to add operators for specific datasets + +4. For first-time use, download datasets and initial rounds by setting `download(["datasets", "initial_rounds"])` in `examples/aflow/optimize.py` + 5. (Optional) Add your custom dataset and corresponding evaluation function following the [Custom Datasets](#custom-datasets) section + 6. (Optional) If you want to use a portion of the validation data, you can set `va_list` in `examples/aflow/evaluator.py` -6. Run `python -m examples.aflow.optimize` to start the optimization process! +7. Run the optimization: + ```bash + # Using default parameters + python -m examples.aflow.optimize + + # Or with custom parameters + python -m examples.aflow.optimize --dataset MATH --sample 4 --question_type math + ``` ## Reproduce the Results in the Paper -1. We provide the raw data obtained from our experiments (link), including the workflows and prompts generated in each iteration, as well as their trajectories on the validation dataset. We also provide the optimal workflow for each dataset and the corresponding data on the test dataset. You can download these data using `metagpt/ext/aflow/data/download_data.py`. +1. We provide the raw data obtained from our experiments ([download link](https://drive.google.com/uc?export=download&id=1Sr5wjgKf3bN8OC7G6cO3ynzJqD4w6_Dv)), including the workflows and prompts generated in each iteration, as well as their trajectories on the validation dataset. We also provide the optimal workflow for each dataset and the corresponding data on the test dataset. You can download these data using `metagpt/ext/aflow/data/download_data.py`. 2. You can directly reproduce our experimental results by running the scripts in `examples/aflow/experiments`. diff --git a/examples/aflow/experiments/optimize_drop.py b/examples/aflow/experiments/optimize_drop.py index 80342e1a62..801c5222bb 100644 --- a/examples/aflow/experiments/optimize_drop.py +++ b/examples/aflow/experiments/optimize_drop.py @@ -6,14 +6,6 @@ from metagpt.configs.models_config import ModelsConfig from metagpt.ext.aflow.scripts.optimizer import DatasetType, Optimizer, QuestionType -# DatasetType, QuestionType, and OptimizerType definitions -# DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"] -# QuestionType = Literal["math", "code", "qa"] -# OptimizerType = Literal["Graph", "Test"] - -# When you fisrt use, please download the datasets and initial rounds; If you want to get a look of the results, please download the results. -# download(["datasets", "initial_rounds"]) - # Crucial Parameters dataset: DatasetType = "DROP" # Ensure the type is consistent with DatasetType sample: int = 4 # Sample Count, which means how many workflows will be resampled from generated workflows diff --git a/examples/aflow/experiments/optimize_gsm8k.py b/examples/aflow/experiments/optimize_gsm8k.py index 3895838ffa..e34fdb66d9 100644 --- a/examples/aflow/experiments/optimize_gsm8k.py +++ b/examples/aflow/experiments/optimize_gsm8k.py @@ -6,14 +6,6 @@ from metagpt.configs.models_config import ModelsConfig from metagpt.ext.aflow.scripts.optimizer import DatasetType, Optimizer, QuestionType -# DatasetType, QuestionType, and OptimizerType definitions -# DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"] -# QuestionType = Literal["math", "code", "qa"] -# OptimizerType = Literal["Graph", "Test"] - -# When you fisrt use, please download the datasets and initial rounds; If you want to get a look of the results, please download the results. -# download(["datasets", "initial_rounds"]) - # Crucial Parameters dataset: DatasetType = "GSM8K" # Ensure the type is consistent with DatasetType sample: int = 4 # Sample Count, which means how many workflows will be resampled from generated workflows diff --git a/examples/aflow/experiments/optimize_hotpotqa.py b/examples/aflow/experiments/optimize_hotpotqa.py index 447759f071..92d26ddd59 100644 --- a/examples/aflow/experiments/optimize_hotpotqa.py +++ b/examples/aflow/experiments/optimize_hotpotqa.py @@ -6,14 +6,6 @@ from metagpt.configs.models_config import ModelsConfig from metagpt.ext.aflow.scripts.optimizer import DatasetType, Optimizer, QuestionType -# DatasetType, QuestionType, and OptimizerType definitions -# DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"] -# QuestionType = Literal["math", "code", "qa"] -# OptimizerType = Literal["Graph", "Test"] - -# When you fisrt use, please download the datasets and initial rounds; If you want to get a look of the results, please download the results. -# download(["datasets", "initial_rounds"]) - # Crucial Parameters dataset: DatasetType = "HotpotQA" # Ensure the type is consistent with DatasetType sample: int = 4 # Sample Count, which means how many workflows will be resampled from generated workflows diff --git a/examples/aflow/experiments/optimize_humaneval.py b/examples/aflow/experiments/optimize_humaneval.py index 90a051aab1..6027e9ec86 100644 --- a/examples/aflow/experiments/optimize_humaneval.py +++ b/examples/aflow/experiments/optimize_humaneval.py @@ -6,14 +6,6 @@ from metagpt.configs.models_config import ModelsConfig from metagpt.ext.aflow.scripts.optimizer import DatasetType, Optimizer, QuestionType -# DatasetType, QuestionType, and OptimizerType definitions -# DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"] -# QuestionType = Literal["math", "code", "qa"] -# OptimizerType = Literal["Graph", "Test"] - -# When you fisrt use, please download the datasets and initial rounds; If you want to get a look of the results, please download the results. -# download(["datasets", "initial_rounds"]) - # Crucial Parameters dataset: DatasetType = "HumanEval" # Ensure the type is consistent with DatasetType sample: int = 4 # Sample Count, which means how many workflows will be resampled from generated workflows diff --git a/examples/aflow/experiments/optimize_math.py b/examples/aflow/experiments/optimize_math.py index 60e54d64a9..5d951c1680 100644 --- a/examples/aflow/experiments/optimize_math.py +++ b/examples/aflow/experiments/optimize_math.py @@ -6,14 +6,6 @@ from metagpt.configs.models_config import ModelsConfig from metagpt.ext.aflow.scripts.optimizer import DatasetType, Optimizer, QuestionType -# DatasetType, QuestionType, and OptimizerType definitions -# DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"] -# QuestionType = Literal["math", "code", "qa"] -# OptimizerType = Literal["Graph", "Test"] - -# When you fisrt use, please download the datasets and initial rounds; If you want to get a look of the results, please download the results. -# download(["datasets", "initial_rounds"]) - # Crucial Parameters dataset: DatasetType = "MATH" # Ensure the type is consistent with DatasetType sample: int = 4 # Sample Count, which means how many workflows will be resampled from generated workflows diff --git a/examples/aflow/experiments/optimize_mbpp.py b/examples/aflow/experiments/optimize_mbpp.py index b7e7d306f7..00c008bbf1 100644 --- a/examples/aflow/experiments/optimize_mbpp.py +++ b/examples/aflow/experiments/optimize_mbpp.py @@ -6,14 +6,6 @@ from metagpt.configs.models_config import ModelsConfig from metagpt.ext.aflow.scripts.optimizer import DatasetType, Optimizer, QuestionType -# DatasetType, QuestionType, and OptimizerType definitions -# DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"] -# QuestionType = Literal["math", "code", "qa"] -# OptimizerType = Literal["Graph", "Test"] - -# When you fisrt use, please download the datasets and initial rounds; If you want to get a look of the results, please download the results. -# download(["datasets", "initial_rounds"]) - # Crucial Parameters dataset: DatasetType = "MBPP" # Ensure the type is consistent with DatasetType sample: int = 4 # Sample Count, which means how many workflows will be resampled from generated workflows diff --git a/examples/aflow/optimize.py b/examples/aflow/optimize.py index d2bfbc7753..65b194344c 100644 --- a/examples/aflow/optimize.py +++ b/examples/aflow/optimize.py @@ -3,25 +3,33 @@ # @Author : didi # @Desc : Entrance of AFlow. +import argparse + from metagpt.configs.models_config import ModelsConfig from metagpt.ext.aflow.data.download_data import download -from metagpt.ext.aflow.scripts.optimizer import DatasetType, Optimizer, QuestionType +from metagpt.ext.aflow.scripts.optimizer import Optimizer # DatasetType, QuestionType, and OptimizerType definitions # DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"] # QuestionType = Literal["math", "code", "qa"] # OptimizerType = Literal["Graph", "Test"] -# Crucial Parameters -dataset: DatasetType = "MATH" # Ensure the type is consistent with DatasetType -sample: int = 4 # Sample Count, which means how many workflows will be resampled from generated workflows -question_type: QuestionType = "math" # Ensure the type is consistent with QuestionType -optimized_path: str = "metagpt/ext/aflow/scripts/optimized" # Optimized Result Save Path -initial_round: int = 1 # Corrected the case from Initial_round to initial_round -max_rounds: int = 20 # The max iteration of AFLOW. -check_convergence: bool = True # Whether Early Stop -validation_rounds: int = 5 # The validation rounds of AFLOW. -if_fisrt_optimize = True # You should change it to False after the first optimize. + +def parse_args(): + parser = argparse.ArgumentParser(description="AFlow Optimizer") + parser.add_argument("--dataset", type=str, default="MATH", help="Dataset type") + parser.add_argument("--sample", type=int, default=4, help="Sample count") + parser.add_argument("--question_type", type=str, default="math", help="Question type") + parser.add_argument( + "--optimized_path", type=str, default="metagpt/ext/aflow/scripts/optimized", help="Optimized result save path" + ) + parser.add_argument("--initial_round", type=int, default=1, help="Initial round") + parser.add_argument("--max_rounds", type=int, default=20, help="Max iteration rounds") + parser.add_argument("--check_convergence", type=bool, default=True, help="Whether to enable early stop") + parser.add_argument("--validation_rounds", type=int, default=5, help="Validation rounds") + parser.add_argument("--if_first_optimize", type=bool, default=True, help="Whether this is first optimization") + return parser.parse_args() + # Config llm model, you can modify `config/config2.yaml` to use more llms. mini_llm_config = ModelsConfig.default().get("gpt-4o-mini") @@ -37,24 +45,26 @@ "Programmer", # It's for math ] -# Create an optimizer instance -optimizer = Optimizer( - dataset=dataset, # Config dataset - question_type=question_type, # Config Question Type - opt_llm_config=claude_llm_config, # Config Optimizer LLM - exec_llm_config=mini_llm_config, # Config Execution LLM - check_convergence=check_convergence, # Whether Early Stop - operators=operators, # Config Operators you want to use - optimized_path=optimized_path, # Config Optimized workflow's file path - sample=sample, # Only Top(sample) rounds will be selected. - initial_round=initial_round, # Optimize from initial round - max_rounds=max_rounds, # The max iteration of AFLOW. - validation_rounds=validation_rounds, # The validation rounds of AFLOW. -) - if __name__ == "__main__": + args = parse_args() + + # Create an optimizer instance + optimizer = Optimizer( + dataset=args.dataset, # Config dataset + question_type=args.question_type, # Config Question Type + opt_llm_config=claude_llm_config, # Config Optimizer LLM + exec_llm_config=mini_llm_config, # Config Execution LLM + check_convergence=args.check_convergence, # Whether Early Stop + operators=operators, # Config Operators you want to use + optimized_path=args.optimized_path, # Config Optimized workflow's file path + sample=args.sample, # Only Top(sample) rounds will be selected. + initial_round=args.initial_round, # Optimize from initial round + max_rounds=args.max_rounds, # The max iteration of AFLOW. + validation_rounds=args.validation_rounds, # The validation rounds of AFLOW. + ) + # When you fisrt use, please download the datasets and initial rounds; If you want to get a look of the results, please download the results. - download(["datasets", "initial_rounds"], if_first_download=if_fisrt_optimize) + download(["datasets", "initial_rounds"], if_first_download=args.if_first_optimize) # Optimize workflow via setting the optimizer's mode to 'Graph' optimizer.optimize("Graph") # Test workflow via setting the optimizer's mode to 'Test' diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index c286b2fdd3..ab190b736f 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -19,12 +19,12 @@ from metagpt.actions.action_outcls_registry import register_action_outcls from metagpt.const import USE_CONFIG_TIMEOUT -from metagpt.ext.aflow.scripts.utils import sanitize from metagpt.llm import BaseLLM from metagpt.logs import logger from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess from metagpt.utils.common import OutputParser, general_after_log from metagpt.utils.human_interaction import HumanInteraction +from metagpt.utils.sanitize import sanitize class ReviewMode(Enum): @@ -527,7 +527,9 @@ def xml_compile(self, context): """ return context - async def code_fill(self, context, function_name=None, timeout=USE_CONFIG_TIMEOUT): + async def code_fill( + self, context: str, function_name: Optional[str] = None, timeout: int = USE_CONFIG_TIMEOUT + ) -> Dict[str, str]: """ Fill CodeBlock Using ``` ``` """ @@ -538,21 +540,21 @@ async def code_fill(self, context, function_name=None, timeout=USE_CONFIG_TIMEOU result = {field_name: extracted_code} return result - async def single_fill(self, context): + async def single_fill(self, context: str) -> Dict[str, str]: field_name = self.get_field_name() prompt = context content = await self.llm.aask(prompt) result = {field_name: content} return result - async def xml_fill(self, context): + async def xml_fill(self, context: str) -> Dict[str, Any]: """ - 使用XML标签填充上下文并根据字段类型进行转换,包括字符串、整数、布尔值、列表和字典类型 + Fill context with XML tags and convert according to field types, including string, integer, boolean, list and dict types """ field_names = self.get_field_names() field_types = self.get_field_types() - extracted_data = {} + extracted_data: Dict[str, Any] = {} content = await self.llm.aask(context) for field_name in field_names: diff --git a/metagpt/ext/aflow/benchmark/benchmark.py b/metagpt/ext/aflow/benchmark/benchmark.py index abdf546f57..b5692f01e6 100644 --- a/metagpt/ext/aflow/benchmark/benchmark.py +++ b/metagpt/ext/aflow/benchmark/benchmark.py @@ -3,6 +3,7 @@ import os from abc import ABC, abstractmethod from datetime import datetime +from pathlib import Path from typing import Any, Callable, List, Tuple import aiofiles @@ -10,6 +11,7 @@ from tqdm.asyncio import tqdm_asyncio from metagpt.logs import logger +from metagpt.utils.common import write_json_file class BaseBenchmark(ABC): @@ -18,6 +20,9 @@ def __init__(self, name: str, file_path: str, log_path: str): self.file_path = file_path self.log_path = log_path + PASS = "PASS" + FAIL = "FAIL" + async def load_data(self, specific_indices: List[int] = None) -> List[dict]: data = [] async with aiofiles.open(self.file_path, mode="r", encoding="utf-8") as file: @@ -55,9 +60,9 @@ def log_mismatch( "extracted_output": extracted_output, "extract_answer_code": extract_answer_code, } - log_file = os.path.join(self.log_path, "log.json") - if os.path.exists(log_file): - with open(log_file, "r", encoding="utf-8") as f: + log_file = Path(self.log_path) / "log.json" + if log_file.exists(): + with log_file.open("r", encoding="utf-8") as f: try: data = json.load(f) except json.JSONDecodeError: @@ -65,8 +70,7 @@ def log_mismatch( else: data = [] data.append(log_data) - with open(log_file, "w", encoding="utf-8") as f: - json.dump(data, f, indent=4, ensure_ascii=False) + write_json_file(log_file, data, encoding="utf-8", indent=4) @abstractmethod async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[Any, ...]: diff --git a/metagpt/ext/aflow/benchmark/humaneval.py b/metagpt/ext/aflow/benchmark/humaneval.py index 36771ad7a8..b54add260f 100644 --- a/metagpt/ext/aflow/benchmark/humaneval.py +++ b/metagpt/ext/aflow/benchmark/humaneval.py @@ -6,17 +6,14 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark -from metagpt.ext.aflow.scripts.utils import sanitize from metagpt.logs import logger +from metagpt.utils.sanitize import sanitize class HumanEvalBenchmark(BaseBenchmark): def __init__(self, name: str, file_path: str, log_path: str): super().__init__(name, file_path, log_path) - PASS = "PASS" - FAIL = "FAIL" - class TimeoutError(Exception): pass diff --git a/metagpt/ext/aflow/benchmark/mbpp.py b/metagpt/ext/aflow/benchmark/mbpp.py index 62279ec188..c3628b0240 100644 --- a/metagpt/ext/aflow/benchmark/mbpp.py +++ b/metagpt/ext/aflow/benchmark/mbpp.py @@ -5,17 +5,14 @@ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark -from metagpt.ext.aflow.scripts.utils import sanitize from metagpt.logs import logger +from metagpt.utils.sanitize import sanitize class MBPPBenchmark(BaseBenchmark): def __init__(self, name: str, file_path: str, log_path: str): super().__init__(name, file_path, log_path) - PASS = "PASS" - FAIL = "FAIL" - class TimeoutError(Exception): pass diff --git a/metagpt/ext/aflow/benchmark/utils.py b/metagpt/ext/aflow/benchmark/utils.py index 944fde6bec..60cbe5580e 100644 --- a/metagpt/ext/aflow/benchmark/utils.py +++ b/metagpt/ext/aflow/benchmark/utils.py @@ -4,7 +4,6 @@ @Time : 2024/7/24 16:37 @Author : didi @File : utils.py -@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py """ import json @@ -12,6 +11,8 @@ import numpy as np +from metagpt.utils.common import write_json_file + def generate_random_indices(n, n_samples, test=False): """ @@ -41,13 +42,6 @@ def split_data_set(file_path, samples, test=False): return data -# save data into a jsonl file -def save_data(data, file_path): - with open(file_path, "w") as file: - for d in data: - file.write(json.dumps(d) + "\n") - - def log_mismatch(problem, expected_output, prediction, predicted_number, path): log_data = { "question": problem, @@ -74,5 +68,4 @@ def log_mismatch(problem, expected_output, prediction, predicted_number, path): data.append(log_data) # 将数据写回到log.json文件 - with open(log_file, "w", encoding="utf-8") as f: - json.dump(data, f, indent=4, ensure_ascii=False) + write_json_file(log_file, data, encoding="utf-8", indent=4) diff --git a/metagpt/ext/aflow/scripts/utils.py b/metagpt/ext/aflow/scripts/utils.py index 36332ffecb..bc97f08188 100644 --- a/metagpt/ext/aflow/scripts/utils.py +++ b/metagpt/ext/aflow/scripts/utils.py @@ -2,62 +2,11 @@ @Time : 2024/7/24 16:37 @Author : didi @File : utils.py -@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py """ -import ast import json import re -import traceback -from enum import Enum -from typing import Any, Dict, Generator, List, Optional, Set, Tuple - -import tree_sitter_python -from tree_sitter import Language, Node, Parser - - -def extract_task_id(task_id: str) -> int: - """Extract the numeric part of the task_id.""" - match = re.search(r"/(\d+)", task_id) - return int(match.group(1)) if match else 0 - - -def get_hotpotqa(path: str): - # Parses each jsonl line and yields it as a dictionary - def parse_jsonl(path): - with open(path) as f: - for line in f: - yield json.loads(line) - - datas = list(parse_jsonl(path)) - return {data["_id"]: data for data in datas} - - -def sort_json_by_key(input_file: str, output_file: str, key: str = "task_id"): - """ - Read a JSONL file, sort the entries based on task_id, and write to a new JSONL file. - - :param input_file: Path to the input JSONL file - :param output_file: Path to the output JSONL file - """ - # Read and parse the JSONL file - with open(input_file, "r") as f: - data = [json.loads(line) for line in f] - - # Sort the data based on the numeric part of task_id - sorted_data = sorted(data, key=lambda x: extract_task_id(x[key])) - - # Write the sorted data to a new JSONL file - with open(output_file, "w") as f: - for item in sorted_data: - f.write(json.dumps(item) + "\n") - - -def parse_python_literal(s): - try: - return ast.literal_eval(s) - except (ValueError, SyntaxError): - return s +from typing import Any, List, Tuple def extract_test_cases_from_jsonl(entry_point: str, dataset: str = "HumanEval"): @@ -168,172 +117,3 @@ def test_check(): test_check() """ return tester_function - - -class NodeType(Enum): - CLASS = "class_definition" - FUNCTION = "function_definition" - IMPORT = ["import_statement", "import_from_statement"] - IDENTIFIER = "identifier" - ATTRIBUTE = "attribute" - RETURN = "return_statement" - EXPRESSION = "expression_statement" - ASSIGNMENT = "assignment" - - -def traverse_tree(node: Node) -> Generator[Node, None, None]: - """ - Traverse the tree structure starting from the given node. - - :param node: The root node to start the traversal from. - :return: A generator object that yields nodes in the tree. - """ - cursor = node.walk() - depth = 0 - - visited_children = False - while True: - if not visited_children: - yield cursor.node - if not cursor.goto_first_child(): - depth += 1 - visited_children = True - elif cursor.goto_next_sibling(): - visited_children = False - elif not cursor.goto_parent() or depth == 0: - break - else: - depth -= 1 - - -def syntax_check(code, verbose=False): - try: - ast.parse(code) - return True - except (SyntaxError, MemoryError): - if verbose: - traceback.print_exc() - return False - - -def code_extract(text: str) -> str: - lines = text.split("\n") - longest_line_pair = (0, 0) - longest_so_far = 0 - - for i in range(len(lines)): - for j in range(i + 1, len(lines)): - current_lines = "\n".join(lines[i : j + 1]) - if syntax_check(current_lines): - current_length = sum(1 for line in lines[i : j + 1] if line.strip()) - if current_length > longest_so_far: - longest_so_far = current_length - longest_line_pair = (i, j) - - return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1]) - - -def get_definition_name(node: Node) -> str: - for child in node.children: - if child.type == NodeType.IDENTIFIER.value: - return child.text.decode("utf8") - - -def has_return_statement(node: Node) -> bool: - traverse_nodes = traverse_tree(node) - for node in traverse_nodes: - if node.type == NodeType.RETURN.value: - return True - return False - - -def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]: - def dfs_get_deps(node: Node, deps: Set[str]) -> None: - for child in node.children: - if child.type == NodeType.IDENTIFIER.value: - deps.add(child.text.decode("utf8")) - else: - dfs_get_deps(child, deps) - - name2deps = {} - for name, node in nodes: - deps = set() - dfs_get_deps(node, deps) - name2deps[name] = deps - return name2deps - - -def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]: - queue = [entrypoint] - visited = {entrypoint} - while queue: - current = queue.pop(0) - if current not in call_graph: - continue - for neighbour in call_graph[current]: - if neighbour not in visited: - visited.add(neighbour) - queue.append(neighbour) - return visited - - -def sanitize(code: str, entrypoint: Optional[str] = None) -> str: - """ - Sanitize and extract relevant parts of the given Python code. - This function parses the input code, extracts import statements, class and function definitions, - and variable assignments. If an entrypoint is provided, it only includes definitions that are - reachable from the entrypoint in the call graph. - - :param code: The input Python code as a string. - :param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis. - :return: A sanitized version of the input code, containing only relevant parts. - """ - code = code_extract(code) - code_bytes = bytes(code, "utf8") - parser = Parser(Language(tree_sitter_python.language())) - tree = parser.parse(code_bytes) - class_names = set() - function_names = set() - variable_names = set() - - root_node = tree.root_node - import_nodes = [] - definition_nodes = [] - - for child in root_node.children: - if child.type in NodeType.IMPORT.value: - import_nodes.append(child) - elif child.type == NodeType.CLASS.value: - name = get_definition_name(child) - if not (name in class_names or name in variable_names or name in function_names): - definition_nodes.append((name, child)) - class_names.add(name) - elif child.type == NodeType.FUNCTION.value: - name = get_definition_name(child) - if not (name in function_names or name in variable_names or name in class_names) and has_return_statement( - child - ): - definition_nodes.append((name, child)) - function_names.add(get_definition_name(child)) - elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value: - subchild = child.children[0] - name = get_definition_name(subchild) - if not (name in variable_names or name in function_names or name in class_names): - definition_nodes.append((name, subchild)) - variable_names.add(name) - - if entrypoint: - name2deps = get_deps(definition_nodes) - reacheable = get_function_dependency(entrypoint, name2deps) - - sanitized_output = b"" - - for node in import_nodes: - sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" - - for pair in definition_nodes: - name, node = pair - if entrypoint and name not in reacheable: - continue - sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" - return sanitized_output[:-1].decode("utf8") diff --git a/metagpt/utils/sanitize.py b/metagpt/utils/sanitize.py new file mode 100644 index 0000000000..a9becbb98f --- /dev/null +++ b/metagpt/utils/sanitize.py @@ -0,0 +1,183 @@ +""" +@Time : 2024/7/24 16:37 +@Author : didi +@File : utils.py +@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py +""" + +import ast +import traceback +from enum import Enum +from typing import Dict, Generator, List, Optional, Set, Tuple + +import tree_sitter_python +from tree_sitter import Language, Node, Parser + + +class NodeType(Enum): + CLASS = "class_definition" + FUNCTION = "function_definition" + IMPORT = ["import_statement", "import_from_statement"] + IDENTIFIER = "identifier" + ATTRIBUTE = "attribute" + RETURN = "return_statement" + EXPRESSION = "expression_statement" + ASSIGNMENT = "assignment" + + +def traverse_tree(node: Node) -> Generator[Node, None, None]: + """ + Traverse the tree structure starting from the given node. + + :param node: The root node to start the traversal from. + :return: A generator object that yields nodes in the tree. + """ + cursor = node.walk() + depth = 0 + + visited_children = False + while True: + if not visited_children: + yield cursor.node + if not cursor.goto_first_child(): + depth += 1 + visited_children = True + elif cursor.goto_next_sibling(): + visited_children = False + elif not cursor.goto_parent() or depth == 0: + break + else: + depth -= 1 + + +def syntax_check(code, verbose=False): + try: + ast.parse(code) + return True + except (SyntaxError, MemoryError): + if verbose: + traceback.print_exc() + return False + + +def code_extract(text: str) -> str: + lines = text.split("\n") + longest_line_pair = (0, 0) + longest_so_far = 0 + + for i in range(len(lines)): + for j in range(i + 1, len(lines)): + current_lines = "\n".join(lines[i : j + 1]) + if syntax_check(current_lines): + current_length = sum(1 for line in lines[i : j + 1] if line.strip()) + if current_length > longest_so_far: + longest_so_far = current_length + longest_line_pair = (i, j) + + return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1]) + + +def get_definition_name(node: Node) -> str: + for child in node.children: + if child.type == NodeType.IDENTIFIER.value: + return child.text.decode("utf8") + + +def has_return_statement(node: Node) -> bool: + traverse_nodes = traverse_tree(node) + for node in traverse_nodes: + if node.type == NodeType.RETURN.value: + return True + return False + + +def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]: + def dfs_get_deps(node: Node, deps: Set[str]) -> None: + for child in node.children: + if child.type == NodeType.IDENTIFIER.value: + deps.add(child.text.decode("utf8")) + else: + dfs_get_deps(child, deps) + + name2deps = {} + for name, node in nodes: + deps = set() + dfs_get_deps(node, deps) + name2deps[name] = deps + return name2deps + + +def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]: + queue = [entrypoint] + visited = {entrypoint} + while queue: + current = queue.pop(0) + if current not in call_graph: + continue + for neighbour in call_graph[current]: + if neighbour not in visited: + visited.add(neighbour) + queue.append(neighbour) + return visited + + +def sanitize(code: str, entrypoint: Optional[str] = None) -> str: + """ + Sanitize and extract relevant parts of the given Python code. + This function parses the input code, extracts import statements, class and function definitions, + and variable assignments. If an entrypoint is provided, it only includes definitions that are + reachable from the entrypoint in the call graph. + + :param code: The input Python code as a string. + :param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis. + :return: A sanitized version of the input code, containing only relevant parts. + """ + code = code_extract(code) + code_bytes = bytes(code, "utf8") + parser = Parser(Language(tree_sitter_python.language())) + tree = parser.parse(code_bytes) + class_names = set() + function_names = set() + variable_names = set() + + root_node = tree.root_node + import_nodes = [] + definition_nodes = [] + + for child in root_node.children: + if child.type in NodeType.IMPORT.value: + import_nodes.append(child) + elif child.type == NodeType.CLASS.value: + name = get_definition_name(child) + if not (name in class_names or name in variable_names or name in function_names): + definition_nodes.append((name, child)) + class_names.add(name) + elif child.type == NodeType.FUNCTION.value: + name = get_definition_name(child) + if not (name in function_names or name in variable_names or name in class_names) and has_return_statement( + child + ): + definition_nodes.append((name, child)) + function_names.add(get_definition_name(child)) + elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value: + subchild = child.children[0] + name = get_definition_name(subchild) + if not (name in variable_names or name in function_names or name in class_names): + definition_nodes.append((name, subchild)) + variable_names.add(name) + + if entrypoint: + name2deps = get_deps(definition_nodes) + reacheable = get_function_dependency(entrypoint, name2deps) + + sanitized_output = b"" + + for node in import_nodes: + sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" + + for pair in definition_nodes: + name, node = pair + if entrypoint and name not in reacheable: + continue + sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" + return sanitized_output[:-1].decode("utf8")