diff --git a/docs/examples.md b/docs/examples.md
index a3c07abcf..424afaf23 100644
--- a/docs/examples.md
+++ b/docs/examples.md
@@ -260,3 +260,57 @@ for question in questions:
response = agent.explain()
print(response)
```
+
+## Add Skills to the Agent
+
+You can add customs functions for the agent to use, allowing the agent to expand its capabilities. These custom functions can be seamlessly integrated with the agent's skills, enabling a wide range of user-defined operations.
+
+```
+import pandas as pd
+from pandasai import Agent
+
+from pandasai.llm.openai import OpenAI
+from pandasai.skills import skill
+
+employees_data = {
+ "EmployeeID": [1, 2, 3, 4, 5],
+ "Name": ["John", "Emma", "Liam", "Olivia", "William"],
+ "Department": ["HR", "Sales", "IT", "Marketing", "Finance"],
+}
+
+salaries_data = {
+ "EmployeeID": [1, 2, 3, 4, 5],
+ "Salary": [5000, 6000, 4500, 7000, 5500],
+}
+
+employees_df = pd.DataFrame(employees_data)
+salaries_df = pd.DataFrame(salaries_data)
+
+
+@skill(
+ name="Display employee salary",
+ description="Plots the employee salaries against names",
+ usage="Displays the plot having name on x axis and salaries on y axis",
+)
+def plot_salaries(merged_df: pd.DataFrame) -> str:
+ import matplotlib.pyplot as plt
+
+ plt.bar(merged_df["Name"], merged_df["Salary"])
+ plt.xlabel("Employee Name")
+ plt.ylabel("Salary")
+ plt.title("Employee Salaries")
+ plt.xticks(rotation=45)
+ plt.savefig("temp_chart.png")
+ plt.close()
+
+
+llm = OpenAI("YOUR_API_KEY")
+agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10)
+
+agent.add_skills(plot_salaries)
+
+# Chat with the agent
+response = agent.chat("Plot the employee salaries against names")
+print(response)
+
+```
diff --git a/docs/skills.md b/docs/skills.md
new file mode 100644
index 000000000..ae6488219
--- /dev/null
+++ b/docs/skills.md
@@ -0,0 +1,113 @@
+# Skills
+
+You can add customs functions for the agent to use, allowing the agent to expand its capabilities. These custom functions can be seamlessly integrated with the agent's skills, enabling a wide range of user-defined operations.
+
+## Example Usage
+
+```python
+
+import pandas as pd
+from pandasai import Agent
+
+from pandasai.llm.openai import OpenAI
+from pandasai.skills import skill
+
+employees_data = {
+ "EmployeeID": [1, 2, 3, 4, 5],
+ "Name": ["John", "Emma", "Liam", "Olivia", "William"],
+ "Department": ["HR", "Sales", "IT", "Marketing", "Finance"],
+}
+
+salaries_data = {
+ "EmployeeID": [1, 2, 3, 4, 5],
+ "Salary": [5000, 6000, 4500, 7000, 5500],
+}
+
+employees_df = pd.DataFrame(employees_data)
+salaries_df = pd.DataFrame(salaries_data)
+
+# Function doc string to give more context to the model for use this skill
+@skill
+def plot_salaries(name: list[str], salaries: list[int]):
+ """
+ Displays the bar chart having name on x axis and salaries on y axis
+ Args:
+ name (list[str]): Employee name
+ salaries (list[int]): Salaries
+ """
+ # plot bars
+ import matplotlib.pyplot as plt
+
+ plt.bar(name, salaries)
+ plt.xlabel("Employee Name")
+ plt.ylabel("Salary")
+ plt.title("Employee Salaries")
+ plt.xticks(rotation=45)
+
+
+
+llm = OpenAI("YOUR_API_KEY")
+agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10)
+
+agent.add_skills(plot_salaries)
+
+# Chat with the agent
+response = agent.chat("Plot the employee salaries against names")
+
+
+```
+
+## Add Streamlit Skill
+
+```python
+import pandas as pd
+from pandasai import Agent
+
+from pandasai.llm.openai import OpenAI
+from pandasai.skills import skill
+import streamlit as st
+
+employees_data = {
+ "EmployeeID": [1, 2, 3, 4, 5],
+ "Name": ["John", "Emma", "Liam", "Olivia", "William"],
+ "Department": ["HR", "Sales", "IT", "Marketing", "Finance"],
+}
+
+salaries_data = {
+ "EmployeeID": [1, 2, 3, 4, 5],
+ "Salary": [5000, 6000, 4500, 7000, 5500],
+}
+
+employees_df = pd.DataFrame(employees_data)
+salaries_df = pd.DataFrame(salaries_data)
+
+# Function doc string to give more context to the model for use this skill
+@skill
+def plot_salaries(name: list[str], salary: list[int]):
+ """
+ Displays the bar chart having name on x axis and salaries on y axis using streamlit
+ Args:
+ name (list[str]): Employee name
+ salaries (list[int]): Salaries
+ """
+ import matplotlib.pyplot as plt
+
+ plt.bar(name, salary)
+ plt.xlabel("Employee Name")
+ plt.ylabel("Salary")
+ plt.title("Employee Salaries")
+ plt.xticks(rotation=45)
+ plt.savefig("temp_chart.png")
+ fig = plt.gcf()
+ st.pyplot(fig)
+
+
+llm = OpenAI("YOUR_API_KEY")
+agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10)
+
+agent.add_skills(plot_salaries_using_streamlit)
+
+# Chat with the agent
+response = agent.chat("Plot the employee salaries against names")
+print(response)
+```
diff --git a/examples/skills_example.py b/examples/skills_example.py
new file mode 100644
index 000000000..e1df24d99
--- /dev/null
+++ b/examples/skills_example.py
@@ -0,0 +1,47 @@
+import pandas as pd
+from pandasai import Agent
+
+from pandasai.llm.openai import OpenAI
+from pandasai.skills import skill
+
+employees_data = {
+ "EmployeeID": [1, 2, 3, 4, 5],
+ "Name": ["John", "Emma", "Liam", "Olivia", "William"],
+ "Department": ["HR", "Sales", "IT", "Marketing", "Finance"],
+}
+
+salaries_data = {
+ "EmployeeID": [1, 2, 3, 4, 5],
+ "Salary": [5000, 6000, 4500, 7000, 5500],
+}
+
+employees_df = pd.DataFrame(employees_data)
+salaries_df = pd.DataFrame(salaries_data)
+
+
+# Add function docstring to give more context to model
+@skill
+def plot_salaries(name: list[str], salary: list[int]) -> str:
+ """
+ Displays the bar chart having name on x axis and salaries on y axis using streamlit
+ Args:
+ name (list[str]): Employee name
+ salaries (list[int]): Salaries
+ """
+ import matplotlib.pyplot as plt
+
+ plt.bar(name, salary)
+ plt.xlabel("Employee Name")
+ plt.ylabel("Salary")
+ plt.title("Employee Salaries")
+ plt.xticks(rotation=45)
+
+
+llm = OpenAI("YOUR-API-KEY")
+agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10)
+
+agent.add_skills(plot_salaries)
+
+# Chat with the agent
+response = agent.chat("Plot the employee salaries against names")
+print(response)
diff --git a/mkdocs.yml b/mkdocs.yml
index 86d71cd87..6808e430d 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -26,6 +26,7 @@ nav:
- callbacks.md
- custom-instructions.md
- custom-prompts.md
+ - skills.md
- custom-whitelisted-dependencies.md
- Examples:
- examples.md
diff --git a/pandasai/__init__.py b/pandasai/__init__.py
index 5e1a08ea3..2c71d8116 100644
--- a/pandasai/__init__.py
+++ b/pandasai/__init__.py
@@ -45,6 +45,7 @@
from .schemas.df_config import Config
from .helpers.cache import Cache
from .agent import Agent
+from .skills import skill
__version__ = importlib.metadata.version(__package__ or __name__)
@@ -257,4 +258,11 @@ def clear_cache(filename: str = None):
cache.clear()
-__all__ = ["PandasAI", "SmartDataframe", "SmartDatalake", "Agent", "clear_cache"]
+__all__ = [
+ "PandasAI",
+ "SmartDataframe",
+ "SmartDatalake",
+ "Agent",
+ "clear_cache",
+ "skill",
+]
diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py
index 455da534d..52034bac0 100644
--- a/pandasai/agent/__init__.py
+++ b/pandasai/agent/__init__.py
@@ -1,5 +1,7 @@
import json
from typing import Union, List, Optional
+
+from pandasai.skills import skill
from ..helpers.df_info import DataFrameType
from ..helpers.logger import Logger
from ..helpers.memory import Memory
@@ -47,6 +49,12 @@ def __init__(
self._logger = self._lake.logger
+ def add_skills(self, *skills: List[skill]):
+ """
+ Add Skills to PandasAI
+ """
+ self._lake.add_skills(*skills)
+
def _call_llm_with_prompt(self, prompt: AbstractPrompt):
"""
Call LLM with prompt using error handling to retry based on config
diff --git a/pandasai/assets/prompt_templates/generate_python_code.tmpl b/pandasai/assets/prompt_templates/generate_python_code.tmpl
index 566ee388e..b0c337e44 100644
--- a/pandasai/assets/prompt_templates/generate_python_code.tmpl
+++ b/pandasai/assets/prompt_templates/generate_python_code.tmpl
@@ -10,7 +10,7 @@ This is the initial python function. Do not change the params. Given the context
```python
{current_code}
```
-
+{skills}
Take a deep breath and reason step-by-step. Act as a senior data analyst.
In the answer, you must never write the "technical" names of the tables.
Based on the last message in the conversation:
diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py
index 6c5e65eee..d015d6dac 100644
--- a/pandasai/helpers/code_manager.py
+++ b/pandasai/helpers/code_manager.py
@@ -6,6 +6,8 @@
import astor
import pandas as pd
+from pandasai.helpers.skills_manager import SkillsManager
+
from .node_visitors import AssignmentVisitor, CallVisitor
from .save_chart import add_save_chart
from .optional import import_dependency
@@ -23,6 +25,29 @@
import traceback
+class CodeExecutionContext:
+ _prompt_id: uuid.UUID = None
+ _skills_manager: SkillsManager = None
+
+ def __init__(self, prompt_id: uuid.UUID, skills_manager: SkillsManager):
+ """
+ Additional Context for code execution
+ Args:
+ prompt_id (uuid.UUID): prompt unique id
+ skill (List): list[functions] of skills added
+ """
+ self._skills_manager = skills_manager
+ self._prompt_id = prompt_id
+
+ @property
+ def prompt_id(self):
+ return self._prompt_id
+
+ @property
+ def skills_manager(self):
+ return self._skills_manager
+
+
class CodeManager:
_dfs: List
_middlewares: List[Middleware] = [ChartsMiddleware()]
@@ -180,11 +205,7 @@ def _required_dfs(self, code: str) -> List[str]:
required_dfs.append(None)
return required_dfs
- def execute_code(
- self,
- code: str,
- prompt_id: uuid.UUID,
- ) -> Any:
+ def execute_code(self, code: str, context: CodeExecutionContext) -> Any:
"""
Execute the python code generated by LLMs to answer the question
about the input dataframe. Run the code in the current context and return the
@@ -192,7 +213,8 @@ def execute_code(
Args:
code (str): Python code to execute.
- prompt_id (uuid.UUID): UUID of the request.
+ context (CodeExecutionContext): Code Execution Context
+ with prompt id and skills.
Returns:
Any: The result of the code execution. The type of the result depends
@@ -209,12 +231,15 @@ def execute_code(
code = add_save_chart(
code,
logger=self._logger,
- file_name=str(prompt_id),
+ file_name=str(context.prompt_id),
save_charts_path=self._config.save_charts_path,
)
+ # Reset used skills
+ context.skills_manager.used_skills = []
+
# Get the code to run removing unsafe imports and df overwrites
- code_to_run = self._clean_code(code)
+ code_to_run = self._clean_code(code, context)
self.last_code_executed = code_to_run
self._logger.log(
f"""
@@ -228,6 +253,13 @@ def execute_code(
# if the code does not need them
dfs = self._required_dfs(code_to_run)
environment: dict = self._get_environment()
+
+ # Add Skills in the env
+ if len(context.skills_manager.used_skills) > 0:
+ for skill_func_name in context.skills_manager.used_skills:
+ skill = context.skills_manager.get_skill_by_func_name(skill_func_name)
+ environment[skill_func_name] = skill
+
environment["dfs"] = self._get_samples(dfs)
caught_error = self._execute_catching_errors(code_to_run, environment)
@@ -293,7 +325,6 @@ def _get_environment(self) -> dict:
Returns (dict): A dictionary of environment variables
"""
-
return {
"pd": pd,
**{
@@ -377,7 +408,7 @@ def _sanitize_analyze_data(self, analyze_data_node: ast.stmt) -> ast.stmt:
analyze_data_node.body = sanitized_analyze_data
return analyze_data_node
- def _clean_code(self, code: str) -> str:
+ def _clean_code(self, code: str, context: CodeExecutionContext) -> str:
"""
A method to clean the code to prevent malicious code execution.
@@ -400,11 +431,24 @@ def _clean_code(self, code: str) -> str:
if isinstance(node, (ast.Import, ast.ImportFrom)):
self._check_imports(node)
continue
+
if isinstance(node, ast.FunctionDef) and node.name == "analyze_data":
analyze_data_node = node
sanitized_analyze_data = self._sanitize_analyze_data(analyze_data_node)
+
+ # Walk inside the function def for used skills
+ if len(context.skills_manager.skills) > 0:
+ for node in ast.walk(analyze_data_node):
+ # Checks for function to get skill name
+ if isinstance(node, ast.Call) and isinstance(
+ node.func, ast.Name
+ ):
+ function_name = node.func.id
+ context.skills_manager.add_used_skill(function_name)
+
new_body.append(sanitized_analyze_data)
continue
+
new_body.append(node)
new_tree = ast.Module(body=new_body)
diff --git a/pandasai/helpers/skills_manager.py b/pandasai/helpers/skills_manager.py
new file mode 100644
index 000000000..c63c5ff6f
--- /dev/null
+++ b/pandasai/helpers/skills_manager.py
@@ -0,0 +1,103 @@
+from typing import List
+
+# from pandasai.skills import skill
+
+
+class SkillsManager:
+ """
+ Manages Custom added Skills and tracks used skills for the query
+ """
+
+ _skills: List
+ _used_skills: List[str]
+
+ def __init__(self) -> None:
+ self._skills = []
+ self._used_skills = []
+
+ def add_skills(self, *skills):
+ """
+ Add skills to the list of skills. If a skill with the same name
+ already exists, raise an error.
+
+ Args:
+ *skills: Variable number of skill objects to add.
+ """
+ for skill in skills:
+ if any(
+ existing_skill.name == skill.name for existing_skill in self._skills
+ ):
+ raise ValueError(f"Skill with name '{skill.name}' already exists.")
+
+ self._skills.extend(skills)
+
+ def skill_exists(self, name: str):
+ """
+ Check if a skill with the given name exists in the list of skills.
+
+ Args:
+ name (str): The name of the skill to check.
+
+ Returns:
+ bool: True if a skill with the given name exists, False otherwise.
+ """
+ return any(skill.name == name for skill in self._skills)
+
+ def get_skill_by_func_name(self, name: str):
+ """
+ Get a skill by its name.
+
+ Args:
+ name (str): The name of the skill to retrieve.
+
+ Returns:
+ Skill or None: The skill with the given name, or None if not found.
+ """
+ for skill in self._skills:
+ if skill.name == name:
+ return skill
+
+ return None
+
+ def add_used_skill(self, skill: str):
+ if self.skill_exists(skill):
+ self._used_skills.append(skill)
+
+ def __str__(self) -> str:
+ """
+ Present all skills
+ Returns:
+ str: _description_
+ """
+ skills_repr = ""
+ for skill in self._skills:
+ skills_repr = skills_repr + skill.print
+
+ return skills_repr
+
+ def prompt_display(self) -> str:
+ """
+ Displays skills for prompt
+ """
+ if len(self._skills) == 0:
+ return
+
+ return (
+ """
+You can also use the following functions, if relevant:
+
+"""
+ + self.__str__()
+ )
+
+ @property
+ def used_skills(self):
+ return self._used_skills
+
+ @used_skills.setter
+ def used_skills(self, value):
+ self._used_skills = value
+
+ @property
+ def skills(self):
+ return self._skills
diff --git a/pandasai/skills/__init__.py b/pandasai/skills/__init__.py
new file mode 100644
index 000000000..cc82c5b2f
--- /dev/null
+++ b/pandasai/skills/__init__.py
@@ -0,0 +1,32 @@
+import inspect
+
+
+def skill(skill_function):
+ def wrapped_function(*args, **kwargs):
+ return skill_function(*args, **kwargs)
+
+ wrapped_function.name = skill_function.__name__
+ wrapped_function.func_def = (
+ """def pandasai.skills.{funcion_name}{signature}""".format(
+ funcion_name=wrapped_function.name,
+ signature=str(inspect.signature(skill_function)),
+ )
+ )
+
+ doc_string = skill_function.__doc__
+
+ wrapped_function.print = (
+ """
+
+{signature}
+{doc_string}
+
+"""
+ ).format(
+ signature=wrapped_function.func_def,
+ doc_string=""" \"\"\"{0}\n \"\"\"""".format(doc_string)
+ if doc_string is not None
+ else "",
+ )
+
+ return wrapped_function
diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py
index 4a024d3bc..28724c1e4 100644
--- a/pandasai/smart_dataframe/__init__.py
+++ b/pandasai/smart_dataframe/__init__.py
@@ -26,6 +26,7 @@
import pydantic
from pandasai.helpers.df_validator import DfValidator
+from pandasai.skills import skill
from ..smart_datalake import SmartDatalake
from ..schemas.df_config import Config
@@ -322,6 +323,12 @@ def add_middlewares(self, *middlewares: Optional[Middleware]):
"""
self.lake.add_middlewares(*middlewares)
+ def add_skills(self, *skills: List[skill]):
+ """
+ Add Skills to PandasAI
+ """
+ self.lake.add_skills(*skills)
+
def chat(self, query: str, output_type: Optional[str] = None):
"""
Run a query on the dataframe.
diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py
index 672dbed7c..69bbeb299 100644
--- a/pandasai/smart_datalake/__init__.py
+++ b/pandasai/smart_datalake/__init__.py
@@ -21,6 +21,9 @@
import logging
import os
import traceback
+from pandasai.helpers.skills_manager import SkillsManager
+
+from pandasai.skills import skill
from pandasai.helpers.query_exec_tracker import QueryExecTracker
@@ -38,7 +41,7 @@
from ..prompts.correct_error_prompt import CorrectErrorPrompt
from ..prompts.generate_python_code import GeneratePythonCodePrompt
from typing import Union, List, Any, Type, Optional
-from ..helpers.code_manager import CodeManager
+from ..helpers.code_manager import CodeExecutionContext, CodeManager
from ..middlewares.base import Middleware
from ..helpers.df_info import DataFrameType
from ..helpers.path import find_project_root
@@ -51,11 +54,11 @@ class SmartDatalake:
_llm: LLM
_cache: Cache = None
_logger: Logger
- _start_time: float
_last_prompt_id: uuid.UUID
_conversation_id: uuid.UUID
_code_manager: CodeManager
_memory: Memory
+ _skills: SkillsManager
_instance: str
_query_exec_tracker: QueryExecTracker
@@ -104,6 +107,8 @@ def __init__(
logger=self.logger,
)
+ self._skills = SkillsManager()
+
if cache:
self._cache = cache
elif self._config.enable_cache:
@@ -210,6 +215,12 @@ def add_middlewares(self, *middlewares: Optional[Middleware]):
"""
self._code_manager.add_middlewares(*middlewares)
+ def add_skills(self, *skills: List[skill]):
+ """
+ Add Skills to PandasAI
+ """
+ self._skills.add_skills(*skills)
+
def _assign_prompt_id(self):
"""Assign a prompt ID"""
@@ -248,6 +259,11 @@ def _get_prompt(
prompt.set_var("dfs", self._dfs)
if "conversation" not in default_values:
prompt.set_var("conversation", self._memory.get_conversation())
+
+ # Adds the skills to prompt if exist else display nothing
+ skills_prompt = self._skills.prompt_display()
+ prompt.set_var("skills", skills_prompt if skills_prompt is not None else "")
+
for key, value in default_values.items():
prompt.set_var(key, value)
@@ -374,10 +390,10 @@ def chat(self, query: str, output_type: Optional[str] = None):
while retry_count < self._config.max_retries:
try:
# Execute the code
- result = self._query_exec_tracker.execute_func(
- self._code_manager.execute_code,
+ context = CodeExecutionContext(self._last_prompt_id, self._skills)
+ result = self._code_manager.execute_code(
code=code_to_run,
- prompt_id=self._last_prompt_id,
+ context=context,
)
break
diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py
index 8cc3a0aa6..05ddb764e 100644
--- a/tests/prompts/test_generate_python_code_prompt.py
+++ b/tests/prompts/test_generate_python_code_prompt.py
@@ -51,6 +51,7 @@ def test_str_with_args(self, save_charts_path, output_type_hint):
prompt.set_var("conversation", "Question")
prompt.set_var("save_charts_path", save_charts_path)
prompt.set_var("output_type_hint", output_type_hint)
+ prompt.set_var("skills", "")
expected_prompt_content = f'''You are provided with the following pandas DataFrames:
@@ -108,6 +109,7 @@ def test_advanced_reasoning_prompt(self):
prompt.set_var("conversation", "Question")
prompt.set_var("save_charts_path", "")
prompt.set_var("output_type_hint", "")
+ prompt.set_var("skills", "")
expected_prompt_content = f'''You are provided with the following pandas DataFrames:
diff --git a/tests/skills/test_skills.py b/tests/skills/test_skills.py
new file mode 100644
index 000000000..ed979951c
--- /dev/null
+++ b/tests/skills/test_skills.py
@@ -0,0 +1,347 @@
+from typing import Optional
+from unittest.mock import MagicMock, Mock, patch
+import uuid
+import pandas as pd
+
+import pytest
+from pandasai.agent import Agent
+from pandasai.helpers.code_manager import CodeExecutionContext, CodeManager
+
+from pandasai.helpers.skills_manager import SkillsManager
+from pandasai.llm.fake import FakeLLM
+from pandasai.skills import skill
+from pandasai.smart_dataframe import SmartDataframe
+
+
+class TestSkills:
+ @pytest.fixture
+ def llm(self, output: Optional[str] = None):
+ return FakeLLM(output=output)
+
+ @pytest.fixture
+ def sample_df(self):
+ return pd.DataFrame(
+ {
+ "country": [
+ "United States",
+ "United Kingdom",
+ "France",
+ "Germany",
+ "Italy",
+ "Spain",
+ "Canada",
+ "Australia",
+ "Japan",
+ "China",
+ ],
+ "gdp": [
+ 19294482071552,
+ 2891615567872,
+ 2411255037952,
+ 3435817336832,
+ 1745433788416,
+ 1181205135360,
+ 1607402389504,
+ 1490967855104,
+ 4380756541440,
+ 14631844184064,
+ ],
+ "happiness_index": [
+ 6.94,
+ 7.16,
+ 6.66,
+ 7.07,
+ 6.38,
+ 6.4,
+ 7.23,
+ 7.22,
+ 5.87,
+ 5.12,
+ ],
+ }
+ )
+
+ @pytest.fixture
+ def smart_dataframe(self, llm, sample_df):
+ return SmartDataframe(sample_df, config={"llm": llm, "enable_cache": False})
+
+ @pytest.fixture
+ def code_manager(self, smart_dataframe: SmartDataframe):
+ return smart_dataframe.lake._code_manager
+
+ @pytest.fixture
+ def exec_context(self) -> MagicMock:
+ context = MagicMock(spec=CodeExecutionContext)
+ return context
+
+ @pytest.fixture
+ def agent(self, llm, sample_df):
+ return Agent(sample_df, config={"llm": llm, "enable_cache": False})
+
+ def test_add_skills(self):
+ skills_manager = SkillsManager()
+ skill1 = Mock(name="SkillA", print="SkillA Print")
+ skill2 = Mock(name="SkillB", print="SkillB Print")
+ skills_manager.add_skills(skill1, skill2)
+
+ # Ensure that skills are added
+ assert skill1 in skills_manager.skills
+ assert skill2 in skills_manager.skills
+
+ # Test that adding a skill with the same name raises an error
+ try:
+ skills_manager.add_skills(skill1)
+ except ValueError as e:
+ assert str(e) == f"Skill with name '{skill1.name}' already exists."
+ else:
+ assert False, "Expected ValueError"
+
+ def test_skill_exists(self):
+ skills_manager = SkillsManager()
+ skill1 = MagicMock()
+ skill2 = MagicMock()
+ skill1.name = "SkillA"
+ skill2.name = "SkillB"
+ skills_manager.add_skills(skill1, skill2)
+
+ assert skills_manager.skill_exists("SkillA")
+ assert skills_manager.skill_exists("SkillB")
+
+ # Test that a non-existing skill is not found
+ assert not skills_manager.skill_exists("SkillC")
+
+ def test_get_skill_by_func_name(self):
+ skills_manager = SkillsManager()
+ skill1 = Mock()
+ skill2 = Mock()
+ skill1.name = "SkillA"
+ skill2.name = "SkillB"
+ skills_manager.add_skills(skill1, skill2)
+
+ # Test that you can retrieve a skill by its function name
+ retrieved_skill = skills_manager.get_skill_by_func_name("SkillA")
+ assert retrieved_skill == skill1
+
+ # Test that a non-existing skill returns None
+ retrieved_skill = skills_manager.get_skill_by_func_name("SkillC")
+ assert retrieved_skill is None
+
+ def test_add_used_skill(self):
+ skills_manager = SkillsManager()
+ skill1 = Mock()
+ skill2 = Mock()
+ skill1.name = "SkillA"
+ skill2.name = "SkillB"
+ skills_manager.add_skills(skill1, skill2)
+
+ # Test adding used skills
+ skills_manager.add_used_skill("SkillA")
+ skills_manager.add_used_skill("SkillB")
+
+ # Ensure used skills are added to the used_skills list
+ assert "SkillA" in skills_manager.used_skills
+ assert "SkillB" in skills_manager.used_skills
+
+ def test_prompt_display(self):
+ skills_manager = SkillsManager()
+ skill1 = Mock()
+ skill2 = Mock()
+ skill1.name = "SkillA"
+ skill2.name = "SkillB"
+ skill1.print = "SkillA"
+ skill2.print = "SkillB"
+ skills_manager.add_skills(skill1, skill2)
+
+ # Test prompt_display method when skills exist
+ prompt = skills_manager.prompt_display()
+ assert "You can also use the following functions" in prompt
+
+ # Test prompt_display method when no skills exist
+ skills_manager._skills = []
+ prompt = skills_manager.prompt_display()
+ assert prompt is None
+
+ @patch("pandasai.skills.inspect.signature", return_value="(a, b, c)")
+ def test_skill_decorator(self, mock_inspect_signature):
+ # Define skills using the decorator
+ @skill
+ def skill_a(*args, **kwargs):
+ return "SkillA Result"
+
+ @skill
+ def skill_b(*args, **kwargs):
+ return "SkillB Result"
+
+ # Test the wrapped functions
+ assert skill_a() == "SkillA Result"
+ assert skill_b() == "SkillB Result"
+
+ # Test the additional attributes added by the decorator
+ assert skill_a.name == "skill_a"
+ assert skill_b.name == "skill_b"
+
+ assert skill_a.func_def == "def pandasai.skills.skill_a(a, b, c)"
+ assert skill_b.func_def == "def pandasai.skills.skill_b(a, b, c)"
+
+ assert (
+ skill_a.print
+ == """\n\ndef pandasai.skills.skill_a(a, b, c)\n\n\n""" # noqa: E501
+ )
+ assert (
+ skill_b.print
+ == """\n\ndef pandasai.skills.skill_b(a, b, c)\n\n\n""" # noqa: E501
+ )
+
+ @patch("pandasai.skills.inspect.signature", return_value="(a, b, c)")
+ def test_skill_decorator_test_codc(self, llm):
+ df = pd.DataFrame({"country": []})
+ df = SmartDataframe(df, config={"llm": llm, "enable_cache": False})
+
+ # Define skills using the decorator
+ @skill
+ def plot_salaries(*args, **kwargs):
+ """
+ Test skill A
+ Args:
+ arg(str)
+ """
+ return "SkillA Result"
+
+ function_def = """
+ Test skill A
+ Args:
+ arg(str)
+""" # noqa: E501
+
+ assert function_def in plot_salaries.print
+
+ def test_add_skills_with_agent(self, agent: Agent):
+ # Define skills using the decorator
+ @skill
+ def skill_a(*args, **kwargs):
+ return "SkillA Result"
+
+ @skill
+ def skill_b(*args, **kwargs):
+ return "SkillB Result"
+
+ agent.add_skills(skill_a)
+ assert len(agent._lake._skills.skills) == 1
+
+ agent._lake._skills._skills = []
+ agent.add_skills(skill_a, skill_b)
+ assert len(agent._lake._skills.skills) == 2
+
+ def test_add_skills_with_smartDataframe(self, smart_dataframe: SmartDataframe):
+ # Define skills using the decorator
+ @skill
+ def skill_a(*args, **kwargs):
+ return "SkillA Result"
+
+ @skill
+ def skill_b(*args, **kwargs):
+ return "SkillB Result"
+
+ smart_dataframe.add_skills(skill_a)
+ assert len(smart_dataframe._lake._skills.skills) == 1
+
+ smart_dataframe._lake._skills._skills = []
+ smart_dataframe.add_skills(skill_a, skill_b)
+ assert len(smart_dataframe._lake._skills.skills) == 2
+
+ def test_run_prompt(self, llm):
+ df = pd.DataFrame({"country": []})
+ df = SmartDataframe(df, config={"llm": llm, "enable_cache": False})
+
+ function_def = """
+
+def pandasai.skills.plot_salaries(merged_df: pandas.core.frame.DataFrame) -> str
+
+
+""" # noqa: E501
+
+ @skill
+ def plot_salaries(merged_df: pd.DataFrame) -> str:
+ import matplotlib.pyplot as plt
+
+ plt.bar(merged_df["Name"], merged_df["Salary"])
+ plt.xlabel("Employee Name")
+ plt.ylabel("Salary")
+ plt.title("Employee Salaries")
+ plt.xticks(rotation=45)
+ plt.savefig("temp_chart.png")
+ plt.close()
+
+ df.add_skills(plot_salaries)
+
+ df.chat("How many countries are in the dataframe?")
+ last_prompt = df.last_prompt
+ assert function_def in last_prompt
+
+ def test_run_prompt_agent(self, agent):
+ function_def = """
+
+def pandasai.skills.plot_salaries(merged_df: pandas.core.frame.DataFrame) -> str
+
+
+""" # noqa: E501
+
+ @skill
+ def plot_salaries(merged_df: pd.DataFrame) -> str:
+ import matplotlib.pyplot as plt
+
+ plt.bar(merged_df["Name"], merged_df["Salary"])
+ plt.xlabel("Employee Name")
+ plt.ylabel("Salary")
+ plt.title("Employee Salaries")
+ plt.xticks(rotation=45)
+ plt.savefig("temp_chart.png")
+ plt.close()
+
+ agent.add_skills(plot_salaries)
+
+ agent.chat("How many countries are in the dataframe?")
+ last_prompt = agent._lake.last_prompt
+
+ assert function_def in last_prompt
+
+ def test_run_prompt_without_skills(self, agent):
+ agent.chat("How many countries are in the dataframe?")
+
+ last_prompt = agent._lake.last_prompt
+
+ assert "" not in last_prompt
+ assert "" not in last_prompt
+ assert (
+ "You can also use the following functions, if relevant:" not in last_prompt
+ )
+
+ def test_code_exec_with_skills_no_use(
+ self, code_manager: CodeManager, exec_context: MagicMock
+ ):
+ code = """def analyze_data(dfs):
+ return {'type': 'number', 'value': 1 + 1}"""
+ skill1 = MagicMock()
+ skill1.name = "SkillA"
+ exec_context._skills_manager._skills = [skill1]
+ code_manager.execute_code(code, exec_context)
+ assert len(exec_context._skills_manager.used_skills) == 0
+
+ def test_code_exec_with_skills(self, code_manager: CodeManager):
+ code = """def analyze_data(dfs):
+ plot_salaries()
+ return {'type': 'number', 'value': 1 + 1}"""
+
+ @skill
+ def plot_salaries() -> str:
+ return "plot_salaries"
+
+ code_manager._middlewares = []
+
+ sm = SkillsManager()
+ sm.add_skills(plot_salaries)
+ exec_context = CodeExecutionContext(uuid.uuid4(), sm)
+ code_manager.execute_code(code, exec_context)
+
+ assert len(exec_context._skills_manager.used_skills) == 1
+ assert exec_context._skills_manager.used_skills[0] == "plot_salaries"
diff --git a/tests/test_codemanager.py b/tests/test_codemanager.py
index b5c39933a..0df494429 100644
--- a/tests/test_codemanager.py
+++ b/tests/test_codemanager.py
@@ -1,7 +1,6 @@
"""Unit tests for the CodeManager class"""
-import uuid
from typing import Optional
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
import pandas as pd
import pytest
@@ -11,7 +10,7 @@
from pandasai.smart_dataframe import SmartDataframe
-from pandasai.helpers.code_manager import CodeManager
+from pandasai.helpers.code_manager import CodeExecutionContext, CodeManager
class TestCodeManager:
@@ -72,23 +71,33 @@ def smart_dataframe(self, llm, sample_df):
def code_manager(self, smart_dataframe: SmartDataframe):
return smart_dataframe.lake._code_manager
- def test_run_code_for_calculations(self, code_manager: CodeManager):
+ @pytest.fixture
+ def exec_context(self) -> MagicMock:
+ context = MagicMock(spec=CodeExecutionContext)
+ return context
+
+ def test_run_code_for_calculations(
+ self, code_manager: CodeManager, exec_context: MagicMock
+ ):
code = """def analyze_data(dfs):
return {'type': 'number', 'value': 1 + 1}"""
-
- assert code_manager.execute_code(code, uuid.uuid4())["value"] == 2
+ assert code_manager.execute_code(code, exec_context)["value"] == 2
assert code_manager.last_code_executed == code
- def test_run_code_invalid_code(self, code_manager: CodeManager):
+ def test_run_code_invalid_code(
+ self, code_manager: CodeManager, exec_context: MagicMock
+ ):
with pytest.raises(Exception):
# noinspection PyStatementEffect
- code_manager.execute_code("1+ ", uuid.uuid4())["value"]
+ code_manager.execute_code("1+ ", exec_context)["value"]
- def test_clean_code_remove_builtins(self, code_manager: CodeManager):
+ def test_clean_code_remove_builtins(
+ self, code_manager: CodeManager, exec_context: MagicMock
+ ):
builtins_code = """import set
def analyze_data(dfs):
return {'type': 'number', 'value': set([1, 2, 3])}"""
- assert code_manager.execute_code(builtins_code, uuid.uuid4())["value"] == {
+ assert code_manager.execute_code(builtins_code, exec_context)["value"] == {
1,
2,
3,
@@ -99,44 +108,57 @@ def analyze_data(dfs):
return {'type': 'number', 'value': set([1, 2, 3])}"""
)
- def test_clean_code_removes_jailbreak_code(self, code_manager: CodeManager):
+ def test_clean_code_removes_jailbreak_code(
+ self, code_manager: CodeManager, exec_context: MagicMock
+ ):
malicious_code = """def analyze_data(dfs):
__builtins__['str'].__class__.__mro__[-1].__subclasses__()[140].__init__.__globals__['system']('ls')
print('hello world')"""
assert (
- code_manager._clean_code(malicious_code)
+ code_manager._clean_code(malicious_code, exec_context)
== """def analyze_data(dfs):
print('hello world')"""
)
- def test_clean_code_remove_environment_defaults(self, code_manager: CodeManager):
+ def test_clean_code_remove_environment_defaults(
+ self, code_manager: CodeManager, exec_context: MagicMock
+ ):
pandas_code = """import pandas as pd
print('hello world')
"""
- assert code_manager._clean_code(pandas_code) == "print('hello world')"
+ assert (
+ code_manager._clean_code(pandas_code, exec_context)
+ == "print('hello world')"
+ )
- def test_clean_code_whitelist_import(self, code_manager: CodeManager):
+ def test_clean_code_whitelist_import(
+ self, code_manager: CodeManager, exec_context: MagicMock
+ ):
"""Test that an installed whitelisted library is added to the environment."""
safe_code = """
import numpy as np
np.array()
"""
- assert code_manager._clean_code(safe_code) == "np.array()"
+ assert code_manager._clean_code(safe_code, exec_context) == "np.array()"
- def test_clean_code_raise_bad_import_error(self, code_manager: CodeManager):
+ def test_clean_code_raise_bad_import_error(
+ self, code_manager: CodeManager, exec_context: MagicMock
+ ):
malicious_code = """
import os
print(os.listdir())
"""
with pytest.raises(BadImportError):
- code_manager.execute_code(malicious_code, uuid.uuid4())
+ code_manager.execute_code(malicious_code, exec_context)
- def test_remove_dfs_overwrites(self, code_manager: CodeManager):
+ def test_remove_dfs_overwrites(
+ self, code_manager: CodeManager, exec_context: MagicMock
+ ):
hallucinated_code = """def analyze_data(dfs):
dfs = [pd.DataFrame([1,2,3])]
print(dfs)"""
assert (
- code_manager._clean_code(hallucinated_code)
+ code_manager._clean_code(hallucinated_code, exec_context)
== """def analyze_data(dfs):
print(dfs)"""
)
@@ -157,7 +179,9 @@ def test_exception_handling(
)
assert smart_dataframe.last_error == "No code found in the answer."
- def test_custom_whitelisted_dependencies(self, code_manager: CodeManager, llm):
+ def test_custom_whitelisted_dependencies(
+ self, code_manager: CodeManager, llm, exec_context: MagicMock
+ ):
code = """
import my_custom_library
def analyze_data(dfs: list):
@@ -166,11 +190,11 @@ def analyze_data(dfs: list):
llm._output = code
with pytest.raises(BadImportError):
- code_manager._clean_code(code)
+ code_manager._clean_code(code, exec_context)
code_manager._config.custom_whitelisted_dependencies = ["my_custom_library"]
assert (
- code_manager._clean_code(code)
+ code_manager._clean_code(code, exec_context)
== """def analyze_data(dfs: list):
my_custom_library.do_something()"""
)