From 2bb2c12b23b6d8c29142ba0c9c070e35a85a0904 Mon Sep 17 00:00:00 2001
From: Niels Provos
Date: Wed, 5 Feb 2025 15:38:07 -0800
Subject: [PATCH] feat: add a plan creation pattern

This creates a sub-graph that takes an initial user request for a plan and
runs it through a refinement step.
---
 src/planai/patterns/planner.py        | 237 ++++++++++++++++++++++++++
 tests/planai/patterns/test_planner.py | 171 +++++++++++++++++++
 2 files changed, 408 insertions(+)
 create mode 100644 src/planai/patterns/planner.py
 create mode 100644 tests/planai/patterns/test_planner.py

diff --git a/src/planai/patterns/planner.py b/src/planai/patterns/planner.py
new file mode 100644
index 0000000..3ffac8a
--- /dev/null
+++ b/src/planai/patterns/planner.py
@@ -0,0 +1,237 @@
+from textwrap import dedent
+from typing import List, Tuple, Type
+
+from llm_interface import LLMInterface
+from pydantic import Field
+
+from ..cached_task import CachedTaskWorker
+from ..graph import Graph
+from ..graph_task import SubGraphWorker
+from ..joined_task import InitialTaskWorker, JoinedTaskWorker
+from ..llm_task import CachedLLMTaskWorker
+from ..task import Task, TaskWorker
+
+
+class PlanRequest(Task):
+    request: str = Field(..., description="The original request to create a plan for")
+
+
+class PlanDraft(Task):
+    plan: str = Field(..., description="The draft plan in markdown format")
+
+
+class PlanCritique(Task):
+    comprehensiveness: float = Field(
+        ..., description="Score for how comprehensive the plan is (0-1)"
+    )
+    detail_orientation: float = Field(
+        ..., description="Score for how detailed the plan is (0-1)"
+    )
+    goal_achievement: float = Field(
+        ...,
+        description="Score for how well the plan achieves the original request goals (0-1)",
+    )
+    overall_score: float = Field(..., description="Combined score (0-1)")
+    improvement_suggestions: str = Field(
+        ..., description="Suggestions for improving the plan"
+    )
+
+
+class RefinementRequest(Task):
+    original_request: str = Field(..., description="The original request")
+    plans: List[str] = Field(..., description="The plans to be refined")
+    critiques: List[PlanCritique] = Field(
+        ..., description="The critiques for each plan"
+    )
+
+
+class FinalPlan(Task):
+    plan: str = Field(..., description="The final refined plan")
+    rationale: str = Field(..., description="Explanation of how the plan was refined")
+
+
+class PlanEntryWorker(CachedTaskWorker):
+    output_types: List[Type[Task]] = [PlanRequest]
+    num_variations: int = Field(3, description="Number of plan variations to generate")
+
+    def consume_work(self, task: PlanRequest):
+        for _ in range(self.num_variations):
+            self.publish_work(task=task.copy_public(), input_task=task)
+
+
+class PlanCreator(CachedLLMTaskWorker):
+    output_types: List[Type[Task]] = [PlanDraft]
+    llm_input_type: Type[Task] = PlanRequest
+    prompt: str = dedent(
+        """
+        Create a detailed plan in markdown format based on the following request:
+        {request}
+
+        The plan should be:
+        - Comprehensive and well-structured
+        - Detailed and actionable
+        - Realistic and feasible
+
+        Provide the plan in markdown format using appropriate headers, lists, and sections.
+        """
+    ).strip()
+
+
+class PlanCritiqueWorker(CachedLLMTaskWorker):
+    output_types: List[Type[Task]] = [PlanCritique]
+    llm_input_type: Type[Task] = PlanDraft
+    prompt: str = dedent(
+        """
+        Evaluate the following plan based on these criteria:
+
+        Plan to evaluate:
+        {plan}
+
+        Original request:
+        {request}
+
+        Score each criterion from 0 (worst) to 1 (best):
+        1. Comprehensiveness: How complete and thorough is the plan?
+        2. Detail Orientation: How specific and actionable are the steps?
+        3. Goal Achievement: How well does the plan fulfill the goals of the original request?
+
+        Provide improvement suggestions focused on the weakest aspects.
+
+        Output should be JSON with: comprehensiveness, detail_orientation, goal_achievement, overall_score, improvement_suggestions
+        """
+    ).strip()
+
+    def post_process(self, response: PlanCritique, input_task: PlanDraft):
+        comp = min(1, max(0, response.comprehensiveness))
+        detail = min(1, max(0, response.detail_orientation))
+        goal = min(1, max(0, response.goal_achievement))
+
+        # weight goal achievement more heavily; weights sum to 1 so the
+        # combined score stays within the documented 0-1 range
+        response.overall_score = 0.3 * comp + 0.3 * detail + 0.4 * goal
+
+        return super().post_process(response, input_task)
+
+
+class PlanCritiqueJoiner(JoinedTaskWorker):
+    output_types: List[Type[Task]] = [RefinementRequest]
+    join_type: Type[TaskWorker] = InitialTaskWorker
+
+    def consume_work_joined(self, tasks: List[PlanCritique]):
+        if not tasks:
+            raise ValueError("No critiques to join")
+
+        plans = []
+        critiques = []
+        original_request = ""
+
+        for critique in tasks:
+            plan_draft: PlanDraft = critique.find_input_task(PlanDraft)
+            if plan_draft is None:
+                raise ValueError("PlanDraft not found in critique input tasks")
+            if not original_request:
+                plan_request: PlanRequest = critique.find_input_task(PlanRequest)
+                if plan_request is None:
+                    raise ValueError("PlanRequest not found in critique input tasks")
+                original_request = plan_request.request
+            plans.append(plan_draft.plan)
+            critiques.append(critique)
+
+        self.publish_work(
+            RefinementRequest(
+                original_request=original_request, plans=plans, critiques=critiques
+            ),
+            input_task=tasks[0],
+        )
+
+
+class PlanRefinementWorker(CachedLLMTaskWorker):
+    output_types: List[Type[Task]] = [FinalPlan]
+    llm_input_type: Type[Task] = RefinementRequest
+    use_xml: bool = True
+    prompt: str = dedent(
+        """
+        Create a refined, optimized plan by combining the best elements of multiple plans.
+
+        Original Request:
+        {original_request}
+
+        Available Plans:
+        {plans}
+
+        Plan Critiques:
+        {critiques}
+
+        Create a final plan that:
+        1. Incorporates the strongest elements from each plan
+        2. Addresses the improvement suggestions from the critiques
+        3. Forms a cohesive and comprehensive solution
+
+        Provide your response as JSON with:
+        - plan: The final refined plan in markdown format
+        - rationale: Brief explanation of how you combined and improved the plans
+        """
+    ).strip()
+
+
+def create_planning_graph(
+    llm: LLMInterface, name: str = "PlanningWorker", num_variations: int = 3
+) -> Tuple[Graph, TaskWorker, TaskWorker]:
+    """Creates a planning graph with multiple workers for plan generation and refinement.
+
+    This function sets up a directed graph of workers that collaborate to create and refine plans.
+    The graph includes workers for plan entry, creation, critique, joining critiques, and refinement.
+
+    Args:
+        llm (LLMInterface): Language model interface used by the workers
+        name (str, optional): Base name for the graph. Defaults to "PlanningWorker"
+        num_variations (int, optional): Number of plan variations to generate. Defaults to 3
+
+    Returns:
+        Tuple[Graph, TaskWorker, TaskWorker]: A tuple containing:
+            - The constructed planning graph
+            - The entry worker node
+            - The refinement worker node
+    """
+    graph = Graph(name=f"{name}Graph", strict=True)
+
+    entry = PlanEntryWorker(num_variations=num_variations)
+    creator = PlanCreator(llm=llm)
+    critique = PlanCritiqueWorker(llm=llm)
+    joiner = PlanCritiqueJoiner()
+    refinement = PlanRefinementWorker(llm=llm)
+
+    graph.add_workers(entry, creator, critique, joiner, refinement)
+
+    graph.set_dependency(entry, creator).next(critique).next(joiner).next(refinement)
+
+    return graph, entry, refinement
+
+
+def create_planning_worker(
+    llm: LLMInterface, name: str = "PlanningWorker", num_variations: int = 2
+) -> TaskWorker:
+    """Creates a SubGraphWorker for plan generation and refinement.
+
+    This worker creates a subgraph that:
+    1. Generates multiple plan variations
+    2. Critiques each plan
+    3. Combines the best elements into a final plan
+
+    Args:
+        llm: LLM interface for plan generation and analysis
+        name: Name for the worker
+        num_variations: Number of plan variations to generate. Defaults to 2
+
+    Input Task:
+        PlanRequest: Task containing the original request to create a plan for
+
+    Output Task:
+        FinalPlan: The refined final plan with rationale
+    """
+    graph, entry, refinement = create_planning_graph(
+        llm=llm, name=name, num_variations=num_variations
+    )
+
+    return SubGraphWorker(
+        name=name, graph=graph, entry_worker=entry, exit_worker=refinement
+    )
diff --git a/tests/planai/patterns/test_planner.py b/tests/planai/patterns/test_planner.py
new file mode 100644
index 0000000..9452410
--- /dev/null
+++ b/tests/planai/patterns/test_planner.py
@@ -0,0 +1,171 @@
+import unittest
+from typing import Optional
+
+from planai import Graph, TaskWorker
+from planai.patterns.planner import (
+    FinalPlan,
+    PlanCritique,
+    PlanDraft,
+    PlanRequest,
+    RefinementRequest,
+    create_planning_graph,
+    create_planning_worker,
+)
+from planai.testing import (
+    MockCache,
+    MockLLM,
+    MockLLMResponse,
+    inject_mock_cache,
+    unregister_output_type,
+)
+
+
+class TestPlanner(unittest.TestCase):
+    def setUp(self):
+        # Set up mock cache
+        self.mock_cache = MockCache(dont_store=True)
+
+        # Set up mock LLM with different responses for each worker type
+        self.mock_llm = MockLLM(
+            responses=[
+                # PlanCreator responses
+                MockLLMResponse(
+                    pattern="Create a detailed plan.*",
+                    response=PlanDraft(
+                        plan="# Test Plan\n## Steps\n1. First step\n2. Second step"
+                    ),
+                ),
+                # PlanCritiqueWorker responses
+                MockLLMResponse(
+                    pattern="Evaluate the following plan.*",
+                    response=PlanCritique(
+                        comprehensiveness=0.8,
+                        detail_orientation=0.7,
+                        goal_achievement=0.9,
+                        overall_score=0.8,
+                        improvement_suggestions="Add more detail to step two",
+                    ),
+                ),
+                # PlanRefinementWorker responses
+                MockLLMResponse(
+                    pattern="Create a refined, optimized plan.*",
+                    response=FinalPlan(
+                        plan="# Refined Plan\n## Steps\n1. Detailed first step\n2. Enhanced second step",
+                        rationale="Combined best elements and added detail",
+                    ),
+                ),
+            ]
+        )
+
+    def test_planning_workflow(self):
+        # Create main graph and inject mock cache
+        graph = Graph(name="TestGraph")
+        planning = create_planning_worker(llm=self.mock_llm, name="TestPlanning")
+        graph.add_workers(planning)
+        graph.set_sink(planning, FinalPlan)
+        inject_mock_cache(graph, self.mock_cache)
+
+        # Create initial request
+        request = PlanRequest(request="Create a plan for testing")
+        initial_work = [(planning, request)]
+
+        # Run the graph
+        graph.run(
+            initial_tasks=initial_work, run_dashboard=False, display_terminal=False
+        )
+
+        # Get output tasks
+        output_tasks = graph.get_output_tasks()
+
+        # Should have one final plan
+        self.assertEqual(len(output_tasks), 1)
+        final_plan = output_tasks[0]
+        self.assertIsInstance(final_plan, FinalPlan)
+        self.assertTrue("Refined Plan" in final_plan.plan)
+        self.assertTrue(final_plan.rationale)
+
+    def test_planning_graph_workflow(self):
+        # Create graph using the plain graph version
+        graph, entry_worker, exit_worker = create_planning_graph(
+            llm=self.mock_llm, name="TestPlanning", num_variations=2
+        )
+        graph.set_sink(exit_worker, FinalPlan)
+
+        # Inject mock cache into graph
+        inject_mock_cache(graph, self.mock_cache)
+
+        # Create initial request
+        request = PlanRequest(request="Create a plan for testing")
+        initial_work = [(entry_worker, request)]
+
+        # Run the graph
+        graph.run(
+            initial_tasks=initial_work, run_dashboard=False, display_terminal=False
+        )
+
+        # Get output tasks
+        output_tasks = graph.get_output_tasks()
+
+        # Should have one final plan
+        self.assertEqual(len(output_tasks), 1)
+        final_plan = output_tasks[0]
+        self.assertIsInstance(final_plan, FinalPlan)
+
+    def test_plan_variations(self):
+        # Test that correct number of variations are generated
+        graph, entry_worker, exit_worker = create_planning_graph(
+            llm=self.mock_llm, name="TestPlanning", num_variations=3
+        )
+
+        # Add a sink to capture PlanDraft tasks
+        planner: Optional[TaskWorker] = graph.get_worker_by_output_type(PlanDraft)
+        assert planner is not None
+        unregister_output_type(planner, PlanDraft)
+        graph.set_sink(planner, PlanDraft)
+        inject_mock_cache(graph, self.mock_cache)
+
+        # Create and run request
+        request = PlanRequest(request="Create a plan for testing")
+        initial_work = [(entry_worker, request)]
+        graph.run(
+            initial_tasks=initial_work, run_dashboard=False, display_terminal=False
+        )
+
+        # Get PlanDraft tasks
+        draft_tasks = [t for t in graph.get_output_tasks() if isinstance(t, PlanDraft)]
+        self.assertEqual(len(draft_tasks), 3)
+
+    def test_critique_joining(self):
+        graph, entry_worker, exit_worker = create_planning_graph(
+            llm=self.mock_llm, name="TestPlanning", num_variations=2
+        )
+
+        # Add a sink to capture RefinementRequest tasks
+        refiner = graph.get_worker_by_output_type(RefinementRequest)
+        assert refiner is not None
+        unregister_output_type(refiner, RefinementRequest)
+        graph.set_sink(refiner, RefinementRequest)
+        inject_mock_cache(graph, self.mock_cache)
+
+        # Create and run request
+        request = PlanRequest(request="Create a plan for testing")
+        initial_work = [(entry_worker, request)]
+        graph.run(
+            initial_tasks=initial_work, run_dashboard=False, display_terminal=False
+        )
+
+        # Get RefinementRequest tasks
+        refinement_tasks = [
+            t for t in graph.get_output_tasks() if isinstance(t, RefinementRequest)
+        ]
+        self.assertEqual(len(refinement_tasks), 1)
+
+        # Verify the refinement request contains all plans and critiques
+        refinement = refinement_tasks[0]
+        self.assertEqual(len(refinement.plans), 2)
+        self.assertEqual(len(refinement.critiques), 2)
+        self.assertTrue(all(isinstance(c, PlanCritique) for c in refinement.critiques))
+
+
+if __name__ == "__main__":
+    unittest.main()
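
For reviewers, a minimal usage sketch of the new pattern, mirroring the wiring exercised in
test_planning_workflow above. The `llm` object is assumed to be any configured LLMInterface
instance; it is not provided by this patch, and the request text is purely illustrative.

    from planai import Graph
    from planai.patterns.planner import FinalPlan, PlanRequest, create_planning_worker

    # llm: a configured LLMInterface instance (assumed, not part of this patch)
    graph = Graph(name="PlanningExample")
    planner = create_planning_worker(llm=llm, name="Planner", num_variations=3)
    graph.add_workers(planner)
    graph.set_sink(planner, FinalPlan)

    # Feed the subgraph a single PlanRequest and collect the refined FinalPlan
    request = PlanRequest(request="Plan a rollout of the new billing system")
    graph.run(
        initial_tasks=[(planner, request)], run_dashboard=False, display_terminal=False
    )

    final_plan = graph.get_output_tasks()[0]
    print(final_plan.plan)
    print(final_plan.rationale)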