docs: Add documentation for prompt optimization module
provos committed Sep 13, 2024
1 parent 54d5fa8 commit b2708e8
Showing 1 changed file with 174 additions and 1 deletion.
175 changes: 174 additions & 1 deletion src/planai/cli_optimize_prompt.py
@@ -1,3 +1,52 @@
# Copyright 2024 Niels Provos
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
PlanAI Prompt Optimization Module

This module implements the 'optimize-prompt' subcommand for the PlanAI tool, which automates the process of refining
prompts for Large Language Models (LLMs). It leverages more advanced LLMs to improve prompt effectiveness through
iterative optimization.

Key Features:
1. Automated Iteration: Runs multiple optimization cycles to progressively improve the prompt.
2. Real Data Integration: Utilizes debug logs with actual input-output pairs from production runs.
3. Dynamic Class Loading: Leverages PlanAI's use of Pydantic to dynamically load and use real production classes.
4. Scoring Mechanism: Employs an LLM with a scoring prompt to evaluate the accuracy and effectiveness of each iteration.
5. Adaptability: Designed to be agnostic to specific use cases, applicable to various LLM tasks.

The module includes several worker classes that form a graph-based optimization pipeline:
- PromptGenerationWorker: Generates improved prompts based on the optimization goal.
- PrepareInputWorker: Prepares input data for optimization from reference data.
- PromptPerformanceWorker: Analyzes the performance of prompts against the optimization goal.
- JoinPromptPerformanceOutput: Combines multiple performance outputs.
- AccumulateCritiqueOutput: Accumulates and ranks prompt critiques over multiple iterations.
- PromptImprovementWorker: Creates an improved prompt based on accumulated critiques.

Usage:
This module is typically invoked through the PlanAI CLI using the 'optimize-prompt' subcommand.
It requires specifying the target Python file, class name, debug log, and optimization goal.

Example:
    planai optimize-prompt --python-file your_app.py --class-name YourLLMTaskWorker
        --debug-log debug/YourLLMTaskWorker.json --goal-prompt "Your optimization goal here"

The module outputs optimized prompts as text files and corresponding metadata as JSON files.

Note: This tool requires a comprehensive debug log with diverse examples for effective optimization.
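
Purely as an illustration of the input-output pairs mentioned above, a debug
log entry might look like the following (the actual field names depend on
PlanAI's debug logging and are an assumption here):

    {"input_task": {"...": "..."}, "response": {"...": "..."}}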
"""

import argparse
import hashlib
import json
@@ -89,6 +138,33 @@ class ImprovedPrompt(Task):


class PromptGenerationWorker(CachedLLMTaskWorker):
"""
A worker class responsible for generating improved prompts based on an optimization goal.

This class uses a Large Language Model (LLM) to analyze the provided prompt template
and suggest improvements to better meet the specified optimization goal.

Attributes:
    output_types (List[Type[Task]]): List containing ImprovedPrompt as the output type.
    prompt (str): The instruction prompt for the LLM to generate improved prompts.

Methods:
    consume_work(task: PromptInput) -> ImprovedPrompt:
        Processes the input task and generates an improved prompt.
    post_process(response: ImprovedPrompt, input_task: PromptInput) -> ImprovedPrompt:
        Sanitizes the generated prompt to preserve required keywords.

The worker ensures that:
- Existing {{keywords}} in the original prompt are maintained for .format() expansion.
- Literal curly braces in the prompt text are properly escaped using double braces.
- The focus is on improving the structure and approach of the prompt rather than specific subject matter.
- The generated prompt is a complete, standalone prompt that can be used as-is.

The output is a JSON object with 'prompt_template' and 'comment' fields, providing
the improved prompt and an explanation of the improvements made.
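
As a purely illustrative snippet (not taken from the codebase; 'text' is a
hypothetical placeholder name), a template that mixes a .format() keyword
with literal JSON braces would look like:

    template = "Summarize {text} as JSON: {{\"summary\": \"...\"}}"
    filled = template.format(text="some input")  # doubled braces render as literal braces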
"""

output_types: List[Type[Task]] = [ImprovedPrompt]
prompt: str = dedent(
"""
@@ -123,12 +199,44 @@ def post_process(self, response: ImprovedPrompt, input_task: PromptInput):


class PrepareInputWorker(TaskWorker):
"""
A worker class responsible for preparing input data for prompt optimization.

This class takes reference data and transforms it into appropriate Pydantic task instances
that can be consumed by the target LLMTaskWorker class for which we are optimizing the prompt.

Attributes:
    random_seed (int): The seed used for deterministic random selection of reference data.
    num_examples (int): The number of examples to produce in each iteration.
    _reference_data (List[Dict[str, Any]]): The raw reference data from which to create tasks.
    _module (Any): The module containing the target LLMTaskWorker class.

Methods:
    __init__(module: Any, task_name: str, reference_data: List[Dict[str, Any]], **data):
        Initializes the worker with the necessary data and configuration.
    consume_work(task: ImprovedPrompt) -> None:
        Processes an improved prompt by selecting reference data and creating appropriate input tasks.

The worker performs several key functions:
1. It dynamically loads the appropriate Pydantic task class based on the target LLMTaskWorker.
2. It selects a subset of the reference data using a deterministic random process.
3. It transforms the selected reference data into instances of the appropriate Pydantic task class.
4. It ensures that the created tasks maintain the necessary provenance information for the optimization process.

This "massaging" of reference data is crucial because it allows the optimization process to use
real-world data in a format that exactly matches what the target LLMTaskWorker expects. This ensures
that the prompt optimization is performed under conditions that closely mimic actual usage scenarios.

The worker uses a combination of the improved prompt's hash and its own random seed to ensure
deterministic but varied selection of reference data across different optimization iterations, as
sketched below.
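
A minimal sketch of that selection scheme, assuming md5 hashing of the prompt
template and Python's random module (the verbatim implementation may differ):

    seed = int(hashlib.md5(task.prompt_template.encode()).hexdigest(), 16)
    rng = random.Random(seed ^ self.random_seed)
    selected = rng.sample(self._reference_data, self.num_examples)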
"""

random_seed: int = Field(
42, description="The random seed to use for selecting reference data"
)
num_examples: int = Field(5, description="The number of examples to produce")
_reference_data: List[Dict[str, Any]] = PrivateAttr(default_factory=list)
_random_seed: int = PrivateAttr()
_module: Any = PrivateAttr()

def __init__(
@@ -211,6 +319,13 @@ def consume_work(self, task: PromptPerformanceInput):


class JoinPromptPerformanceOutput(JoinedTaskWorker):
"""
A worker class that aggregates multiple PromptPerformanceOutput tasks into a single CombinedPromptCritique.
This class is responsible for collecting the individual performance evaluations of a prompt
across multiple examples and combining them into a single, aggregated critique and score.
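
A sketch of the aggregation, assuming 'score' and 'critique' fields on the
joined tasks (field names are inferred from the surrounding types, not verbatim):

    scores = [t.score for t in joined_tasks]
    combined = CombinedPromptCritique(
        score=sum(scores) / len(scores),
        critique=[t.critique for t in joined_tasks],
    )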
"""

output_types: List[Type[Task]] = [CombinedPromptCritique]
join_type: Type[TaskWorker] = PromptGenerationWorker

@@ -280,6 +395,22 @@ def consume_work(self, task: CombinedPromptCritique):


class PromptImprovementWorker(CachedLLMTaskWorker):
"""
A worker class that generates improved prompts based on aggregated critiques and validates them.

This class uses a more powerful LLM to create an improved prompt template based on the
critiques and scores of previous iterations. Crucially, it also validates the generated
prompt by attempting to instantiate it with the target LLMTaskWorker class.

The validation process involves:
1. Temporarily injecting the new prompt into the target LLMTaskWorker.
2. Attempting to generate a full prompt using real input data.
3. Catching any errors that occur during this process.

If the validation fails, the error is captured and fed back to the LLM so that it can
generate a prompt that works correctly for the target LLMTaskWorker class.
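
A sketch of that validation loop; the method and variable names here are
assumptions for illustration, not the verbatim implementation:

    saved_prompt = target_worker.prompt
    try:
        target_worker.prompt = response.prompt_template
        target_worker.get_full_prompt(sample_task)  # assumed API for rendering the prompt
    except Exception as e:
        validation_error = str(e)  # captured and fed back to the LLM
    finally:
        target_worker.prompt = saved_prompt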
"""

output_types: List[Type[Task]] = [ImprovedPrompt]
prompt: str = dedent(
"""
@@ -353,6 +484,48 @@ def extra_validation(
def optimize_prompt(
llm_fast: LLMInterface, llm_reason: LLMInterface, args: argparse.Namespace
):
"""
Orchestrates the prompt optimization process for a given LLMTaskWorker class.

This function sets up and executes a multi-step, iterative process to optimize the prompt
of a specified LLMTaskWorker class. It uses a combination of faster and more advanced LLMs
to generate, evaluate, and improve prompts based on real-world data and specified goals.

Parameters:
    llm_fast (LLMInterface): A faster LLM used for initial prompt evaluations.
    llm_reason (LLMInterface): A more advanced LLM used for in-depth analysis and improvements.
    args (argparse.Namespace): Command-line arguments specifying optimization parameters.

The optimization process follows these main steps:
1. Load the target LLMTaskWorker class and its associated debug log data.
2. Set up a graph of specialized workers for different aspects of optimization (see the
   sketch below):
   - PromptGenerationWorker: Generates new prompt variations.
   - PrepareInputWorker: Prepares real-world data for testing.
   - The target LLMTaskWorker: Used to test prompts with real data.
   - OutputAdapter: Adapts LLMTaskWorker output for analysis.
   - PromptPerformanceWorker: Evaluates prompt performance.
   - JoinPromptPerformanceOutput: Aggregates performance data.
   - AccumulateCritiqueOutput: Accumulates and ranks critiques over iterations.
   - PromptImprovementWorker: Creates improved prompts based on critiques.
3. Execute multiple iterations of this process, each time:
   - Generating new prompts or improving existing ones.
   - Testing these prompts against real-world data.
   - Evaluating and scoring the performance of each prompt.
   - Accumulating critiques and suggestions for improvement.
4. Output the top-performing prompts along with their scores and critiques.

The function handles loading necessary modules, setting up the optimization graph,
injecting prompt awareness into the target LLMTaskWorker, and managing the flow of
tasks through the optimization process.

Output:
- Writes the top-performing prompts to text files.
- Saves detailed metadata about each top prompt to JSON files.

Note:
This process requires a well-prepared debug log with diverse, representative examples
to ensure effective optimization across various use cases.
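
A rough sketch of the graph wiring, assuming PlanAI's Graph API (the worker
variables and setup calls here are illustrative, not the verbatim code):

    graph = Graph(name="PromptOptimization")
    graph.add_workers(generation, prepare_input, target_worker, adapter,
                      performance, join, accumulate, improvement)
    graph.set_dependency(generation, prepare_input)
    graph.run(initial_tasks=[(generation, initial_prompt)])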
"""
if args.config:
# Read from configuration file
try: