diff --git a/docs/examples.md b/docs/examples.md index a3c07abcf..424afaf23 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -260,3 +260,57 @@ for question in questions: response = agent.explain() print(response) ``` + +## Add Skills to the Agent + +You can add customs functions for the agent to use, allowing the agent to expand its capabilities. These custom functions can be seamlessly integrated with the agent's skills, enabling a wide range of user-defined operations. + +``` +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.skills import skill + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + + +@skill( + name="Display employee salary", + description="Plots the employee salaries against names", + usage="Displays the plot having name on x axis and salaries on y axis", +) +def plot_salaries(merged_df: pd.DataFrame) -> str: + import matplotlib.pyplot as plt + + plt.bar(merged_df["Name"], merged_df["Salary"]) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + plt.savefig("temp_chart.png") + plt.close() + + +llm = OpenAI("YOUR_API_KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +agent.add_skills(plot_salaries) + +# Chat with the agent +response = agent.chat("Plot the employee salaries against names") +print(response) + +``` diff --git a/docs/skills.md b/docs/skills.md new file mode 100644 index 000000000..ae6488219 --- /dev/null +++ b/docs/skills.md @@ -0,0 +1,113 @@ +# Skills + +You can add customs functions for the agent to use, allowing the agent to expand its capabilities. These custom functions can be seamlessly integrated with the agent's skills, enabling a wide range of user-defined operations. + +## Example Usage + +```python + +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.skills import skill + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + +# Function doc string to give more context to the model for use this skill +@skill +def plot_salaries(name: list[str], salaries: list[int]): + """ + Displays the bar chart having name on x axis and salaries on y axis + Args: + name (list[str]): Employee name + salaries (list[int]): Salaries + """ + # plot bars + import matplotlib.pyplot as plt + + plt.bar(name, salaries) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + + + +llm = OpenAI("YOUR_API_KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +agent.add_skills(plot_salaries) + +# Chat with the agent +response = agent.chat("Plot the employee salaries against names") + + +``` + +## Add Streamlit Skill + +```python +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.skills import skill +import streamlit as st + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + +# Function doc string to give more context to the model for use this skill +@skill +def plot_salaries(name: list[str], salary: list[int]): + """ + Displays the bar chart having name on x axis and salaries on y axis using streamlit + Args: + name (list[str]): Employee name + salaries (list[int]): Salaries + """ + import matplotlib.pyplot as plt + + plt.bar(name, salary) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + plt.savefig("temp_chart.png") + fig = plt.gcf() + st.pyplot(fig) + + +llm = OpenAI("YOUR_API_KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +agent.add_skills(plot_salaries_using_streamlit) + +# Chat with the agent +response = agent.chat("Plot the employee salaries against names") +print(response) +``` diff --git a/examples/skills_example.py b/examples/skills_example.py new file mode 100644 index 000000000..e1df24d99 --- /dev/null +++ b/examples/skills_example.py @@ -0,0 +1,47 @@ +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.skills import skill + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + + +# Add function docstring to give more context to model +@skill +def plot_salaries(name: list[str], salary: list[int]) -> str: + """ + Displays the bar chart having name on x axis and salaries on y axis using streamlit + Args: + name (list[str]): Employee name + salaries (list[int]): Salaries + """ + import matplotlib.pyplot as plt + + plt.bar(name, salary) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + + +llm = OpenAI("YOUR-API-KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +agent.add_skills(plot_salaries) + +# Chat with the agent +response = agent.chat("Plot the employee salaries against names") +print(response) diff --git a/mkdocs.yml b/mkdocs.yml index 86d71cd87..6808e430d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -26,6 +26,7 @@ nav: - callbacks.md - custom-instructions.md - custom-prompts.md + - skills.md - custom-whitelisted-dependencies.md - Examples: - examples.md diff --git a/pandasai/__init__.py b/pandasai/__init__.py index 5e1a08ea3..2c71d8116 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -45,6 +45,7 @@ from .schemas.df_config import Config from .helpers.cache import Cache from .agent import Agent +from .skills import skill __version__ = importlib.metadata.version(__package__ or __name__) @@ -257,4 +258,11 @@ def clear_cache(filename: str = None): cache.clear() -__all__ = ["PandasAI", "SmartDataframe", "SmartDatalake", "Agent", "clear_cache"] +__all__ = [ + "PandasAI", + "SmartDataframe", + "SmartDatalake", + "Agent", + "clear_cache", + "skill", +] diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 455da534d..52034bac0 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -1,5 +1,7 @@ import json from typing import Union, List, Optional + +from pandasai.skills import skill from ..helpers.df_info import DataFrameType from ..helpers.logger import Logger from ..helpers.memory import Memory @@ -47,6 +49,12 @@ def __init__( self._logger = self._lake.logger + def add_skills(self, *skills: List[skill]): + """ + Add Skills to PandasAI + """ + self._lake.add_skills(*skills) + def _call_llm_with_prompt(self, prompt: AbstractPrompt): """ Call LLM with prompt using error handling to retry based on config diff --git a/pandasai/assets/prompt_templates/generate_python_code.tmpl b/pandasai/assets/prompt_templates/generate_python_code.tmpl index 566ee388e..b0c337e44 100644 --- a/pandasai/assets/prompt_templates/generate_python_code.tmpl +++ b/pandasai/assets/prompt_templates/generate_python_code.tmpl @@ -10,7 +10,7 @@ This is the initial python function. Do not change the params. Given the context ```python {current_code} ``` - +{skills} Take a deep breath and reason step-by-step. Act as a senior data analyst. In the answer, you must never write the "technical" names of the tables. Based on the last message in the conversation: diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index 6c5e65eee..d015d6dac 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -6,6 +6,8 @@ import astor import pandas as pd +from pandasai.helpers.skills_manager import SkillsManager + from .node_visitors import AssignmentVisitor, CallVisitor from .save_chart import add_save_chart from .optional import import_dependency @@ -23,6 +25,29 @@ import traceback +class CodeExecutionContext: + _prompt_id: uuid.UUID = None + _skills_manager: SkillsManager = None + + def __init__(self, prompt_id: uuid.UUID, skills_manager: SkillsManager): + """ + Additional Context for code execution + Args: + prompt_id (uuid.UUID): prompt unique id + skill (List): list[functions] of skills added + """ + self._skills_manager = skills_manager + self._prompt_id = prompt_id + + @property + def prompt_id(self): + return self._prompt_id + + @property + def skills_manager(self): + return self._skills_manager + + class CodeManager: _dfs: List _middlewares: List[Middleware] = [ChartsMiddleware()] @@ -180,11 +205,7 @@ def _required_dfs(self, code: str) -> List[str]: required_dfs.append(None) return required_dfs - def execute_code( - self, - code: str, - prompt_id: uuid.UUID, - ) -> Any: + def execute_code(self, code: str, context: CodeExecutionContext) -> Any: """ Execute the python code generated by LLMs to answer the question about the input dataframe. Run the code in the current context and return the @@ -192,7 +213,8 @@ def execute_code( Args: code (str): Python code to execute. - prompt_id (uuid.UUID): UUID of the request. + context (CodeExecutionContext): Code Execution Context + with prompt id and skills. Returns: Any: The result of the code execution. The type of the result depends @@ -209,12 +231,15 @@ def execute_code( code = add_save_chart( code, logger=self._logger, - file_name=str(prompt_id), + file_name=str(context.prompt_id), save_charts_path=self._config.save_charts_path, ) + # Reset used skills + context.skills_manager.used_skills = [] + # Get the code to run removing unsafe imports and df overwrites - code_to_run = self._clean_code(code) + code_to_run = self._clean_code(code, context) self.last_code_executed = code_to_run self._logger.log( f""" @@ -228,6 +253,13 @@ def execute_code( # if the code does not need them dfs = self._required_dfs(code_to_run) environment: dict = self._get_environment() + + # Add Skills in the env + if len(context.skills_manager.used_skills) > 0: + for skill_func_name in context.skills_manager.used_skills: + skill = context.skills_manager.get_skill_by_func_name(skill_func_name) + environment[skill_func_name] = skill + environment["dfs"] = self._get_samples(dfs) caught_error = self._execute_catching_errors(code_to_run, environment) @@ -293,7 +325,6 @@ def _get_environment(self) -> dict: Returns (dict): A dictionary of environment variables """ - return { "pd": pd, **{ @@ -377,7 +408,7 @@ def _sanitize_analyze_data(self, analyze_data_node: ast.stmt) -> ast.stmt: analyze_data_node.body = sanitized_analyze_data return analyze_data_node - def _clean_code(self, code: str) -> str: + def _clean_code(self, code: str, context: CodeExecutionContext) -> str: """ A method to clean the code to prevent malicious code execution. @@ -400,11 +431,24 @@ def _clean_code(self, code: str) -> str: if isinstance(node, (ast.Import, ast.ImportFrom)): self._check_imports(node) continue + if isinstance(node, ast.FunctionDef) and node.name == "analyze_data": analyze_data_node = node sanitized_analyze_data = self._sanitize_analyze_data(analyze_data_node) + + # Walk inside the function def for used skills + if len(context.skills_manager.skills) > 0: + for node in ast.walk(analyze_data_node): + # Checks for function to get skill name + if isinstance(node, ast.Call) and isinstance( + node.func, ast.Name + ): + function_name = node.func.id + context.skills_manager.add_used_skill(function_name) + new_body.append(sanitized_analyze_data) continue + new_body.append(node) new_tree = ast.Module(body=new_body) diff --git a/pandasai/helpers/skills_manager.py b/pandasai/helpers/skills_manager.py new file mode 100644 index 000000000..c63c5ff6f --- /dev/null +++ b/pandasai/helpers/skills_manager.py @@ -0,0 +1,103 @@ +from typing import List + +# from pandasai.skills import skill + + +class SkillsManager: + """ + Manages Custom added Skills and tracks used skills for the query + """ + + _skills: List + _used_skills: List[str] + + def __init__(self) -> None: + self._skills = [] + self._used_skills = [] + + def add_skills(self, *skills): + """ + Add skills to the list of skills. If a skill with the same name + already exists, raise an error. + + Args: + *skills: Variable number of skill objects to add. + """ + for skill in skills: + if any( + existing_skill.name == skill.name for existing_skill in self._skills + ): + raise ValueError(f"Skill with name '{skill.name}' already exists.") + + self._skills.extend(skills) + + def skill_exists(self, name: str): + """ + Check if a skill with the given name exists in the list of skills. + + Args: + name (str): The name of the skill to check. + + Returns: + bool: True if a skill with the given name exists, False otherwise. + """ + return any(skill.name == name for skill in self._skills) + + def get_skill_by_func_name(self, name: str): + """ + Get a skill by its name. + + Args: + name (str): The name of the skill to retrieve. + + Returns: + Skill or None: The skill with the given name, or None if not found. + """ + for skill in self._skills: + if skill.name == name: + return skill + + return None + + def add_used_skill(self, skill: str): + if self.skill_exists(skill): + self._used_skills.append(skill) + + def __str__(self) -> str: + """ + Present all skills + Returns: + str: _description_ + """ + skills_repr = "" + for skill in self._skills: + skills_repr = skills_repr + skill.print + + return skills_repr + + def prompt_display(self) -> str: + """ + Displays skills for prompt + """ + if len(self._skills) == 0: + return + + return ( + """ +You can also use the following functions, if relevant: + +""" + + self.__str__() + ) + + @property + def used_skills(self): + return self._used_skills + + @used_skills.setter + def used_skills(self, value): + self._used_skills = value + + @property + def skills(self): + return self._skills diff --git a/pandasai/skills/__init__.py b/pandasai/skills/__init__.py new file mode 100644 index 000000000..cc82c5b2f --- /dev/null +++ b/pandasai/skills/__init__.py @@ -0,0 +1,32 @@ +import inspect + + +def skill(skill_function): + def wrapped_function(*args, **kwargs): + return skill_function(*args, **kwargs) + + wrapped_function.name = skill_function.__name__ + wrapped_function.func_def = ( + """def pandasai.skills.{funcion_name}{signature}""".format( + funcion_name=wrapped_function.name, + signature=str(inspect.signature(skill_function)), + ) + ) + + doc_string = skill_function.__doc__ + + wrapped_function.print = ( + """ + +{signature} +{doc_string} + +""" + ).format( + signature=wrapped_function.func_def, + doc_string=""" \"\"\"{0}\n \"\"\"""".format(doc_string) + if doc_string is not None + else "", + ) + + return wrapped_function diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py index 4a024d3bc..28724c1e4 100644 --- a/pandasai/smart_dataframe/__init__.py +++ b/pandasai/smart_dataframe/__init__.py @@ -26,6 +26,7 @@ import pydantic from pandasai.helpers.df_validator import DfValidator +from pandasai.skills import skill from ..smart_datalake import SmartDatalake from ..schemas.df_config import Config @@ -322,6 +323,12 @@ def add_middlewares(self, *middlewares: Optional[Middleware]): """ self.lake.add_middlewares(*middlewares) + def add_skills(self, *skills: List[skill]): + """ + Add Skills to PandasAI + """ + self.lake.add_skills(*skills) + def chat(self, query: str, output_type: Optional[str] = None): """ Run a query on the dataframe. diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 672dbed7c..69bbeb299 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -21,6 +21,9 @@ import logging import os import traceback +from pandasai.helpers.skills_manager import SkillsManager + +from pandasai.skills import skill from pandasai.helpers.query_exec_tracker import QueryExecTracker @@ -38,7 +41,7 @@ from ..prompts.correct_error_prompt import CorrectErrorPrompt from ..prompts.generate_python_code import GeneratePythonCodePrompt from typing import Union, List, Any, Type, Optional -from ..helpers.code_manager import CodeManager +from ..helpers.code_manager import CodeExecutionContext, CodeManager from ..middlewares.base import Middleware from ..helpers.df_info import DataFrameType from ..helpers.path import find_project_root @@ -51,11 +54,11 @@ class SmartDatalake: _llm: LLM _cache: Cache = None _logger: Logger - _start_time: float _last_prompt_id: uuid.UUID _conversation_id: uuid.UUID _code_manager: CodeManager _memory: Memory + _skills: SkillsManager _instance: str _query_exec_tracker: QueryExecTracker @@ -104,6 +107,8 @@ def __init__( logger=self.logger, ) + self._skills = SkillsManager() + if cache: self._cache = cache elif self._config.enable_cache: @@ -210,6 +215,12 @@ def add_middlewares(self, *middlewares: Optional[Middleware]): """ self._code_manager.add_middlewares(*middlewares) + def add_skills(self, *skills: List[skill]): + """ + Add Skills to PandasAI + """ + self._skills.add_skills(*skills) + def _assign_prompt_id(self): """Assign a prompt ID""" @@ -248,6 +259,11 @@ def _get_prompt( prompt.set_var("dfs", self._dfs) if "conversation" not in default_values: prompt.set_var("conversation", self._memory.get_conversation()) + + # Adds the skills to prompt if exist else display nothing + skills_prompt = self._skills.prompt_display() + prompt.set_var("skills", skills_prompt if skills_prompt is not None else "") + for key, value in default_values.items(): prompt.set_var(key, value) @@ -374,10 +390,10 @@ def chat(self, query: str, output_type: Optional[str] = None): while retry_count < self._config.max_retries: try: # Execute the code - result = self._query_exec_tracker.execute_func( - self._code_manager.execute_code, + context = CodeExecutionContext(self._last_prompt_id, self._skills) + result = self._code_manager.execute_code( code=code_to_run, - prompt_id=self._last_prompt_id, + context=context, ) break diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py index 8cc3a0aa6..05ddb764e 100644 --- a/tests/prompts/test_generate_python_code_prompt.py +++ b/tests/prompts/test_generate_python_code_prompt.py @@ -51,6 +51,7 @@ def test_str_with_args(self, save_charts_path, output_type_hint): prompt.set_var("conversation", "Question") prompt.set_var("save_charts_path", save_charts_path) prompt.set_var("output_type_hint", output_type_hint) + prompt.set_var("skills", "") expected_prompt_content = f'''You are provided with the following pandas DataFrames: @@ -108,6 +109,7 @@ def test_advanced_reasoning_prompt(self): prompt.set_var("conversation", "Question") prompt.set_var("save_charts_path", "") prompt.set_var("output_type_hint", "") + prompt.set_var("skills", "") expected_prompt_content = f'''You are provided with the following pandas DataFrames: diff --git a/tests/skills/test_skills.py b/tests/skills/test_skills.py new file mode 100644 index 000000000..ed979951c --- /dev/null +++ b/tests/skills/test_skills.py @@ -0,0 +1,347 @@ +from typing import Optional +from unittest.mock import MagicMock, Mock, patch +import uuid +import pandas as pd + +import pytest +from pandasai.agent import Agent +from pandasai.helpers.code_manager import CodeExecutionContext, CodeManager + +from pandasai.helpers.skills_manager import SkillsManager +from pandasai.llm.fake import FakeLLM +from pandasai.skills import skill +from pandasai.smart_dataframe import SmartDataframe + + +class TestSkills: + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + def smart_dataframe(self, llm, sample_df): + return SmartDataframe(sample_df, config={"llm": llm, "enable_cache": False}) + + @pytest.fixture + def code_manager(self, smart_dataframe: SmartDataframe): + return smart_dataframe.lake._code_manager + + @pytest.fixture + def exec_context(self) -> MagicMock: + context = MagicMock(spec=CodeExecutionContext) + return context + + @pytest.fixture + def agent(self, llm, sample_df): + return Agent(sample_df, config={"llm": llm, "enable_cache": False}) + + def test_add_skills(self): + skills_manager = SkillsManager() + skill1 = Mock(name="SkillA", print="SkillA Print") + skill2 = Mock(name="SkillB", print="SkillB Print") + skills_manager.add_skills(skill1, skill2) + + # Ensure that skills are added + assert skill1 in skills_manager.skills + assert skill2 in skills_manager.skills + + # Test that adding a skill with the same name raises an error + try: + skills_manager.add_skills(skill1) + except ValueError as e: + assert str(e) == f"Skill with name '{skill1.name}' already exists." + else: + assert False, "Expected ValueError" + + def test_skill_exists(self): + skills_manager = SkillsManager() + skill1 = MagicMock() + skill2 = MagicMock() + skill1.name = "SkillA" + skill2.name = "SkillB" + skills_manager.add_skills(skill1, skill2) + + assert skills_manager.skill_exists("SkillA") + assert skills_manager.skill_exists("SkillB") + + # Test that a non-existing skill is not found + assert not skills_manager.skill_exists("SkillC") + + def test_get_skill_by_func_name(self): + skills_manager = SkillsManager() + skill1 = Mock() + skill2 = Mock() + skill1.name = "SkillA" + skill2.name = "SkillB" + skills_manager.add_skills(skill1, skill2) + + # Test that you can retrieve a skill by its function name + retrieved_skill = skills_manager.get_skill_by_func_name("SkillA") + assert retrieved_skill == skill1 + + # Test that a non-existing skill returns None + retrieved_skill = skills_manager.get_skill_by_func_name("SkillC") + assert retrieved_skill is None + + def test_add_used_skill(self): + skills_manager = SkillsManager() + skill1 = Mock() + skill2 = Mock() + skill1.name = "SkillA" + skill2.name = "SkillB" + skills_manager.add_skills(skill1, skill2) + + # Test adding used skills + skills_manager.add_used_skill("SkillA") + skills_manager.add_used_skill("SkillB") + + # Ensure used skills are added to the used_skills list + assert "SkillA" in skills_manager.used_skills + assert "SkillB" in skills_manager.used_skills + + def test_prompt_display(self): + skills_manager = SkillsManager() + skill1 = Mock() + skill2 = Mock() + skill1.name = "SkillA" + skill2.name = "SkillB" + skill1.print = "SkillA" + skill2.print = "SkillB" + skills_manager.add_skills(skill1, skill2) + + # Test prompt_display method when skills exist + prompt = skills_manager.prompt_display() + assert "You can also use the following functions" in prompt + + # Test prompt_display method when no skills exist + skills_manager._skills = [] + prompt = skills_manager.prompt_display() + assert prompt is None + + @patch("pandasai.skills.inspect.signature", return_value="(a, b, c)") + def test_skill_decorator(self, mock_inspect_signature): + # Define skills using the decorator + @skill + def skill_a(*args, **kwargs): + return "SkillA Result" + + @skill + def skill_b(*args, **kwargs): + return "SkillB Result" + + # Test the wrapped functions + assert skill_a() == "SkillA Result" + assert skill_b() == "SkillB Result" + + # Test the additional attributes added by the decorator + assert skill_a.name == "skill_a" + assert skill_b.name == "skill_b" + + assert skill_a.func_def == "def pandasai.skills.skill_a(a, b, c)" + assert skill_b.func_def == "def pandasai.skills.skill_b(a, b, c)" + + assert ( + skill_a.print + == """\n\ndef pandasai.skills.skill_a(a, b, c)\n\n\n""" # noqa: E501 + ) + assert ( + skill_b.print + == """\n\ndef pandasai.skills.skill_b(a, b, c)\n\n\n""" # noqa: E501 + ) + + @patch("pandasai.skills.inspect.signature", return_value="(a, b, c)") + def test_skill_decorator_test_codc(self, llm): + df = pd.DataFrame({"country": []}) + df = SmartDataframe(df, config={"llm": llm, "enable_cache": False}) + + # Define skills using the decorator + @skill + def plot_salaries(*args, **kwargs): + """ + Test skill A + Args: + arg(str) + """ + return "SkillA Result" + + function_def = """ + Test skill A + Args: + arg(str) +""" # noqa: E501 + + assert function_def in plot_salaries.print + + def test_add_skills_with_agent(self, agent: Agent): + # Define skills using the decorator + @skill + def skill_a(*args, **kwargs): + return "SkillA Result" + + @skill + def skill_b(*args, **kwargs): + return "SkillB Result" + + agent.add_skills(skill_a) + assert len(agent._lake._skills.skills) == 1 + + agent._lake._skills._skills = [] + agent.add_skills(skill_a, skill_b) + assert len(agent._lake._skills.skills) == 2 + + def test_add_skills_with_smartDataframe(self, smart_dataframe: SmartDataframe): + # Define skills using the decorator + @skill + def skill_a(*args, **kwargs): + return "SkillA Result" + + @skill + def skill_b(*args, **kwargs): + return "SkillB Result" + + smart_dataframe.add_skills(skill_a) + assert len(smart_dataframe._lake._skills.skills) == 1 + + smart_dataframe._lake._skills._skills = [] + smart_dataframe.add_skills(skill_a, skill_b) + assert len(smart_dataframe._lake._skills.skills) == 2 + + def test_run_prompt(self, llm): + df = pd.DataFrame({"country": []}) + df = SmartDataframe(df, config={"llm": llm, "enable_cache": False}) + + function_def = """ + +def pandasai.skills.plot_salaries(merged_df: pandas.core.frame.DataFrame) -> str + + +""" # noqa: E501 + + @skill + def plot_salaries(merged_df: pd.DataFrame) -> str: + import matplotlib.pyplot as plt + + plt.bar(merged_df["Name"], merged_df["Salary"]) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + plt.savefig("temp_chart.png") + plt.close() + + df.add_skills(plot_salaries) + + df.chat("How many countries are in the dataframe?") + last_prompt = df.last_prompt + assert function_def in last_prompt + + def test_run_prompt_agent(self, agent): + function_def = """ + +def pandasai.skills.plot_salaries(merged_df: pandas.core.frame.DataFrame) -> str + + +""" # noqa: E501 + + @skill + def plot_salaries(merged_df: pd.DataFrame) -> str: + import matplotlib.pyplot as plt + + plt.bar(merged_df["Name"], merged_df["Salary"]) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + plt.savefig("temp_chart.png") + plt.close() + + agent.add_skills(plot_salaries) + + agent.chat("How many countries are in the dataframe?") + last_prompt = agent._lake.last_prompt + + assert function_def in last_prompt + + def test_run_prompt_without_skills(self, agent): + agent.chat("How many countries are in the dataframe?") + + last_prompt = agent._lake.last_prompt + + assert "" not in last_prompt + assert "" not in last_prompt + assert ( + "You can also use the following functions, if relevant:" not in last_prompt + ) + + def test_code_exec_with_skills_no_use( + self, code_manager: CodeManager, exec_context: MagicMock + ): + code = """def analyze_data(dfs): + return {'type': 'number', 'value': 1 + 1}""" + skill1 = MagicMock() + skill1.name = "SkillA" + exec_context._skills_manager._skills = [skill1] + code_manager.execute_code(code, exec_context) + assert len(exec_context._skills_manager.used_skills) == 0 + + def test_code_exec_with_skills(self, code_manager: CodeManager): + code = """def analyze_data(dfs): + plot_salaries() + return {'type': 'number', 'value': 1 + 1}""" + + @skill + def plot_salaries() -> str: + return "plot_salaries" + + code_manager._middlewares = [] + + sm = SkillsManager() + sm.add_skills(plot_salaries) + exec_context = CodeExecutionContext(uuid.uuid4(), sm) + code_manager.execute_code(code, exec_context) + + assert len(exec_context._skills_manager.used_skills) == 1 + assert exec_context._skills_manager.used_skills[0] == "plot_salaries" diff --git a/tests/test_codemanager.py b/tests/test_codemanager.py index b5c39933a..0df494429 100644 --- a/tests/test_codemanager.py +++ b/tests/test_codemanager.py @@ -1,7 +1,6 @@ """Unit tests for the CodeManager class""" -import uuid from typing import Optional -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pandas as pd import pytest @@ -11,7 +10,7 @@ from pandasai.smart_dataframe import SmartDataframe -from pandasai.helpers.code_manager import CodeManager +from pandasai.helpers.code_manager import CodeExecutionContext, CodeManager class TestCodeManager: @@ -72,23 +71,33 @@ def smart_dataframe(self, llm, sample_df): def code_manager(self, smart_dataframe: SmartDataframe): return smart_dataframe.lake._code_manager - def test_run_code_for_calculations(self, code_manager: CodeManager): + @pytest.fixture + def exec_context(self) -> MagicMock: + context = MagicMock(spec=CodeExecutionContext) + return context + + def test_run_code_for_calculations( + self, code_manager: CodeManager, exec_context: MagicMock + ): code = """def analyze_data(dfs): return {'type': 'number', 'value': 1 + 1}""" - - assert code_manager.execute_code(code, uuid.uuid4())["value"] == 2 + assert code_manager.execute_code(code, exec_context)["value"] == 2 assert code_manager.last_code_executed == code - def test_run_code_invalid_code(self, code_manager: CodeManager): + def test_run_code_invalid_code( + self, code_manager: CodeManager, exec_context: MagicMock + ): with pytest.raises(Exception): # noinspection PyStatementEffect - code_manager.execute_code("1+ ", uuid.uuid4())["value"] + code_manager.execute_code("1+ ", exec_context)["value"] - def test_clean_code_remove_builtins(self, code_manager: CodeManager): + def test_clean_code_remove_builtins( + self, code_manager: CodeManager, exec_context: MagicMock + ): builtins_code = """import set def analyze_data(dfs): return {'type': 'number', 'value': set([1, 2, 3])}""" - assert code_manager.execute_code(builtins_code, uuid.uuid4())["value"] == { + assert code_manager.execute_code(builtins_code, exec_context)["value"] == { 1, 2, 3, @@ -99,44 +108,57 @@ def analyze_data(dfs): return {'type': 'number', 'value': set([1, 2, 3])}""" ) - def test_clean_code_removes_jailbreak_code(self, code_manager: CodeManager): + def test_clean_code_removes_jailbreak_code( + self, code_manager: CodeManager, exec_context: MagicMock + ): malicious_code = """def analyze_data(dfs): __builtins__['str'].__class__.__mro__[-1].__subclasses__()[140].__init__.__globals__['system']('ls') print('hello world')""" assert ( - code_manager._clean_code(malicious_code) + code_manager._clean_code(malicious_code, exec_context) == """def analyze_data(dfs): print('hello world')""" ) - def test_clean_code_remove_environment_defaults(self, code_manager: CodeManager): + def test_clean_code_remove_environment_defaults( + self, code_manager: CodeManager, exec_context: MagicMock + ): pandas_code = """import pandas as pd print('hello world') """ - assert code_manager._clean_code(pandas_code) == "print('hello world')" + assert ( + code_manager._clean_code(pandas_code, exec_context) + == "print('hello world')" + ) - def test_clean_code_whitelist_import(self, code_manager: CodeManager): + def test_clean_code_whitelist_import( + self, code_manager: CodeManager, exec_context: MagicMock + ): """Test that an installed whitelisted library is added to the environment.""" safe_code = """ import numpy as np np.array() """ - assert code_manager._clean_code(safe_code) == "np.array()" + assert code_manager._clean_code(safe_code, exec_context) == "np.array()" - def test_clean_code_raise_bad_import_error(self, code_manager: CodeManager): + def test_clean_code_raise_bad_import_error( + self, code_manager: CodeManager, exec_context: MagicMock + ): malicious_code = """ import os print(os.listdir()) """ with pytest.raises(BadImportError): - code_manager.execute_code(malicious_code, uuid.uuid4()) + code_manager.execute_code(malicious_code, exec_context) - def test_remove_dfs_overwrites(self, code_manager: CodeManager): + def test_remove_dfs_overwrites( + self, code_manager: CodeManager, exec_context: MagicMock + ): hallucinated_code = """def analyze_data(dfs): dfs = [pd.DataFrame([1,2,3])] print(dfs)""" assert ( - code_manager._clean_code(hallucinated_code) + code_manager._clean_code(hallucinated_code, exec_context) == """def analyze_data(dfs): print(dfs)""" ) @@ -157,7 +179,9 @@ def test_exception_handling( ) assert smart_dataframe.last_error == "No code found in the answer." - def test_custom_whitelisted_dependencies(self, code_manager: CodeManager, llm): + def test_custom_whitelisted_dependencies( + self, code_manager: CodeManager, llm, exec_context: MagicMock + ): code = """ import my_custom_library def analyze_data(dfs: list): @@ -166,11 +190,11 @@ def analyze_data(dfs: list): llm._output = code with pytest.raises(BadImportError): - code_manager._clean_code(code) + code_manager._clean_code(code, exec_context) code_manager._config.custom_whitelisted_dependencies = ["my_custom_library"] assert ( - code_manager._clean_code(code) + code_manager._clean_code(code, exec_context) == """def analyze_data(dfs: list): my_custom_library.do_something()""" )