diff --git a/examples/sql_direct_config.py b/examples/sql_direct_config.py new file mode 100644 index 000000000..97b2abef8 --- /dev/null +++ b/examples/sql_direct_config.py @@ -0,0 +1,59 @@ +"""Example of using PandasAI with PostgreSQL connectors and direct SQL execution.""" + +from pandasai import SmartDatalake +from pandasai.llm import OpenAI +from pandasai.connectors import PostgreSQLConnector +from pandasai.smart_dataframe import SmartDataframe + + +# With a PostgreSQL database +order = PostgreSQLConnector( + config={ + "host": "localhost", + "port": 5432, + "database": "testdb", + "username": "postgres", + "password": "123456", + "table": "orders", + } +) + +order_details = PostgreSQLConnector( + config={ + "host": "localhost", + "port": 5432, + "database": "testdb", + "username": "postgres", + "password": "123456", + "table": "order_details", + } +) + +products = PostgreSQLConnector( + config={ + "host": "localhost", + "port": 5432, + "database": "testdb", + "username": "postgres", + "password": "123456", + "table": "products", + } +) + + +llm = OpenAI("OPENAI_API_KEY") + + +order_details_smart_df = SmartDataframe( + order_details, + config={"llm": llm, "direct_sql": True}, + description="Contains user order details", +) + + +df = SmartDatalake( + [order_details_smart_df, order, products], + config={"llm": llm, "direct_sql": True}, +) +response = df.chat("return orders with count of distinct products") +print(response) diff --git a/mkdocs.yml b/mkdocs.yml index 0e0c6df48..75ced15e0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -42,7 +42,8 @@ nav: - Documents Building: building_docs.md - License: license.md extra: - version: "1.4.3" + version: "1.4.4" + plugins: - search - mkdocstrings: diff --git a/pandasai/assets/prompt_templates/default_instructions.tmpl b/pandasai/assets/prompt_templates/default_instructions.tmpl new file mode 100644 index 000000000..f72542e20 --- /dev/null +++ b/pandasai/assets/prompt_templates/default_instructions.tmpl @@ -0,0 +1,5 @@ +Analyze the data, using the provided dataframes (`dfs`). + 1. Prepare: Preprocessing and cleaning data if necessary + 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + {viz_library_type} \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/direct_sql_connector.tmpl b/pandasai/assets/prompt_templates/direct_sql_connector.tmpl new file mode 100644 index 000000000..71227310e --- /dev/null +++ b/pandasai/assets/prompt_templates/direct_sql_connector.tmpl @@ -0,0 +1,39 @@ +You are provided with the following samples of sql tables data: + + +{tables} + + + +{conversation} + + +You are provided with the following function that executes the sql query: + +def execute_sql_query(sql_query: str) -> pd.DataFrame +"""This method connects to the database, executes the sql query and returns the dataframe""" + + +This is the initial python function. Do not change the params. + +```python +# TODO import all the dependencies required +import pandas as pd + +def analyze_data() -> dict: + """ + Analyze the data, using the provided tables. + 1. Prepare: generate the sql query to get the data for analysis (grouping, filtering, aggregating, etc.) + 2. Process: execute the query using the execute_sql_query function available to you, which returns a dataframe + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) 
+ {viz_library_type} + At the end, return a dictionary of: + {output_type_hint} + """ +``` + +Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. +Based on the last message in the conversation: + +- return the updated analyze_data function wrapped within `python ` \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/generate_python_code.tmpl b/pandasai/assets/prompt_templates/generate_python_code.tmpl index 28e850b9c..b0c337e44 100644 --- a/pandasai/assets/prompt_templates/generate_python_code.tmpl +++ b/pandasai/assets/prompt_templates/generate_python_code.tmpl @@ -6,8 +6,6 @@ You are provided with the following pandas DataFrames: {conversation} -{viz_library_type} - This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python {current_code} diff --git a/pandasai/assets/prompt_templates/viz_library.tmpl b/pandasai/assets/prompt_templates/viz_library.tmpl new file mode 100644 index 000000000..ee3b7ed25 --- /dev/null +++ b/pandasai/assets/prompt_templates/viz_library.tmpl @@ -0,0 +1 @@ +If the user requests to create a chart, utilize the Python {library} library to generate high-quality graphics that will be saved directly to a file. \ No newline at end of file diff --git a/pandasai/connectors/databricks.py b/pandasai/connectors/databricks.py index 0f70b980b..8ab852cb8 100644 --- a/pandasai/connectors/databricks.py +++ b/pandasai/connectors/databricks.py @@ -63,3 +63,20 @@ def __repr__(self): f"host={self._config.host} port={self._config.port} " f"database={self._config.database} httpPath={str(self._config.httpPath)}" ) + + def equals(self, other): + if isinstance(other, self.__class__): + return ( + self._config.dialect, + self._config.token, + self._config.host, + self._config.port, + self._config.httpPath, + ) == ( + other._config.dialect, + other._config.token, + other._config.host, + other._config.port, + other._config.httpPath, + ) + return False diff --git a/pandasai/connectors/snowflake.py b/pandasai/connectors/snowflake.py index 7120ec45c..abd53a316 100644 --- a/pandasai/connectors/snowflake.py +++ b/pandasai/connectors/snowflake.py @@ -90,3 +90,18 @@ def __repr__(self): f"database={self._config.database} schema={str(self._config.dbSchema)} " f"table={self._config.table}>" ) + + def equals(self, other): + if isinstance(other, self.__class__): + return ( + self._config.dialect, + self._config.account, + self._config.username, + self._config.password, + ) == ( + other._config.dialect, + other._config.account, + other._config.username, + other._config.password, + ) + return False diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index 0e8c9715d..5d4443563 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -5,6 +5,8 @@ import re import os import pandas as pd + +from pandasai.exceptions import MaliciousQueryError from .base import BaseConnector, SQLConnectorConfig, SqliteConnectorConfig from .base import BaseConnectorConfig from sqlalchemy import create_engine, text, select, asc @@ -360,6 +362,46 @@ def column_hash(self): def fallback_name(self): return self._config.table + def equals(self, other): + if isinstance(other, self.__class__): + return ( + self._config.dialect, + self._config.driver, + self._config.host, + self._config.port, + self._config.username, + self._config.password, + ) == ( + other._config.dialect, + other._config.driver, + other._config.host, + 
other._config.port, + other._config.username, + other._config.password, + ) + return False + + def _is_sql_query_safe(self, query: str) -> bool: + infected_keywords = [ + r"\bINSERT\b", + r"\bUPDATE\b", + r"\bDELETE\b", + r"\bDROP\b", + r"\bEXEC\b", + r"\bALTER\b", + r"\bCREATE\b", + ] + + return not any( + re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords + ) + + def execute_direct_sql_query(self, sql_query): + if not self._is_sql_query_safe(sql_query): + raise MaliciousQueryError("Malicious SQL query detected in the generated code") + + return pd.read_sql(sql_query, self._connection) + class SqliteConnector(SQLConnector): """ diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py index 39158c792..c90df7219 100644 --- a/pandasai/exceptions.py +++ b/pandasai/exceptions.py @@ -155,6 +155,7 @@ class UnSupportedLogicUnit(Exception): Exception (Exception): UnSupportedLogicUnit """ + class InvalidWorkspacePathError(Exception): """ Raised when the environment variable of workspace exist but path is invalid @@ -162,3 +163,19 @@ class InvalidWorkspacePathError(Exception): Args: Exception (Exception): InvalidWorkspacePathError """ + + +class InvalidConfigError(Exception): + """ + Raised when a config value is not applicable + Args: + Exception (Exception): InvalidConfigError + """ + + +class MaliciousQueryError(Exception): + """ + Raised when a malicious query is generated + Args: + Exception (Exception): MaliciousQueryError + """ diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index edf7aa0fd..f15f2ecc8 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -28,9 +28,15 @@ class CodeExecutionContext: _prompt_id: uuid.UUID = None + _can_direct_sql: bool = False _skills_manager: SkillsManager = None - def __init__(self, prompt_id: uuid.UUID, skills_manager: SkillsManager): + def __init__( + self, + prompt_id: uuid.UUID, + skills_manager: SkillsManager, + _can_direct_sql: bool = False, + ): """ Additional Context for code execution Args: @@ -39,6 +45,7 @@ def __init__(self, prompt_id: uuid.UUID, skills_manager: SkillsManager): """ self._skills_manager = skills_manager self._prompt_id = prompt_id + self._can_direct_sql = _can_direct_sql @property def prompt_id(self): @@ -48,6 +55,10 @@ def prompt_id(self): def skills_manager(self): return self._skills_manager + @property + def can_direct_sql(self): + return self._can_direct_sql + class CodeManager: _dfs: List @@ -283,6 +294,10 @@ def execute_code(self, code: str, context: CodeExecutionContext) -> Any: analyze_data = environment.get("analyze_data") + if context.can_direct_sql: + environment["execute_sql_query"] = self._dfs[0].get_query_exec_func() + return analyze_data() + return analyze_data(self._get_originals(dfs)) def _get_samples(self, dfs): diff --git a/pandasai/helpers/viz_library_types/_viz_library_types.py b/pandasai/helpers/viz_library_types/_viz_library_types.py index 87b290a8f..3c9ae66e3 100644 --- a/pandasai/helpers/viz_library_types/_viz_library_types.py +++ b/pandasai/helpers/viz_library_types/_viz_library_types.py @@ -1,13 +1,12 @@ from abc import abstractmethod, ABC from typing import Any, Iterable +from pandasai.prompts.generate_python_code import VizLibraryPrompt class BaseVizLibraryType(ABC): @property def template_hint(self) -> str: - return f"""When a user requests to create a chart, utilize the Python -{self.name} library to generate high-quality graphics that will be saved -directly to a file.""" + return VizLibraryPrompt(library=self.name) @property 
@abstractmethod diff --git a/pandasai/prompts/base.py b/pandasai/prompts/base.py index 7a692a18e..1015ea353 100644 --- a/pandasai/prompts/base.py +++ b/pandasai/prompts/base.py @@ -2,6 +2,7 @@ In order to better handle the instructions, this prompt module is written. """ from abc import ABC, abstractmethod +import string class AbstractPrompt(ABC): @@ -92,12 +93,11 @@ def to_string(self): prompt_args = {} for key, value in self._args.items(): if isinstance(value, AbstractPrompt): + args = [ + arg[1] for arg in string.Formatter().parse(value.template) if arg[1] + ] value.set_vars( - { - k: v - for k, v in self._args.items() - if k != key and not isinstance(v, AbstractPrompt) - } + {k: v for k, v in self._args.items() if k != key and k in args} ) prompt_args[key] = value.to_string() else: diff --git a/pandasai/prompts/direct_sql_prompt.py b/pandasai/prompts/direct_sql_prompt.py new file mode 100644 index 000000000..f05e40bb0 --- /dev/null +++ b/pandasai/prompts/direct_sql_prompt.py @@ -0,0 +1,24 @@ +"""Prompt to generate Python code that answers the conversation by running SQL queries directly against the connected database""" +from .file_based_prompt import FileBasedPrompt + + +class DirectSQLPrompt(FileBasedPrompt): + """Prompt to generate Python code for direct SQL execution""" + + _path_to_template = "assets/prompt_templates/direct_sql_connector.tmpl" + + def _prepare_tables_data(self, tables): + tables_join = [] + for table in tables: + table_description_tag = ( + f' description="{table.table_description}"' + if table.table_description is not None + else "" + ) + table_head_tag = f'<table{table_description_tag}>' + table = f"{table_head_tag}\n{table.head_csv}\n</table>
" + tables_join.append(table) + return "\n\n".join(tables_join) + + def setup(self, tables) -> None: + self.set_var("tables", self._prepare_tables_data(tables)) diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py index 7ee7aba0e..a54d4a794 100644 --- a/pandasai/prompts/generate_python_code.py +++ b/pandasai/prompts/generate_python_code.py @@ -26,18 +26,30 @@ class CurrentCodePrompt(FileBasedPrompt): _path_to_template = "assets/prompt_templates/current_code.tmpl" +class DefaultInstructionsPrompt(FileBasedPrompt): + """The default instructions""" + + _path_to_template = "assets/prompt_templates/default_instructions.tmpl" + + class AdvancedReasoningPrompt(FileBasedPrompt): - """The current code""" + """The advanced reasoning instructions""" _path_to_template = "assets/prompt_templates/advanced_reasoning.tmpl" class SimpleReasoningPrompt(FileBasedPrompt): - """The current code""" + """The simple reasoning instructions""" _path_to_template = "assets/prompt_templates/simple_reasoning.tmpl" +class VizLibraryPrompt(FileBasedPrompt): + """Provide information about the visualization library""" + + _path_to_template = "assets/prompt_templates/viz_library.tmpl" + + class GeneratePythonCodePrompt(FileBasedPrompt): """Prompt to generate Python code""" @@ -45,14 +57,11 @@ class GeneratePythonCodePrompt(FileBasedPrompt): def setup(self, **kwargs) -> None: if "custom_instructions" in kwargs: - self._set_instructions(kwargs["custom_instructions"]) - else: - self._set_instructions( - """Analyze the data, using the provided dataframes (`dfs`). -1. Prepare: Preprocessing and cleaning data if necessary -2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) -3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)""" # noqa: E501 + self.set_var( + "instructions", self._format_instructions(kwargs["custom_instructions"]) ) + else: + self.set_var("instructions", DefaultInstructionsPrompt()) if "current_code" in kwargs: self.set_var("current_code", kwargs["current_code"]) @@ -70,8 +79,7 @@ def on_prompt_generation(self) -> None: else: self.set_var("reasoning", SimpleReasoningPrompt()) - def _set_instructions(self, instructions: str): + def _format_instructions(self, instructions: str): lines = instructions.split("\n") indented_lines = [f" {line}" for line in lines[1:]] - result = "\n".join([lines[0]] + indented_lines) - self.set_var("instructions", result) + return "\n".join([lines[0]] + indented_lines) diff --git a/pandasai/schemas/df_config.py b/pandasai/schemas/df_config.py index 2cd1f312d..fa100a90e 100644 --- a/pandasai/schemas/df_config.py +++ b/pandasai/schemas/df_config.py @@ -30,11 +30,11 @@ class Config(BaseModel): max_retries: int = 3 middlewares: List[Middleware] = Field(default_factory=list) callback: Optional[BaseCallback] = None - lazy_load_connector: bool = True response_parser: Type[ResponseParser] = None llm: Any = None data_viz_library: Optional[VisualizationLibrary] = None log_server: LogServerConfig = None + direct_sql: bool = False class Config: arbitrary_types_allowed = True diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py index 278355a1f..83b9eff75 100644 --- a/pandasai/smart_dataframe/__init__.py +++ b/pandasai/smart_dataframe/__init__.py @@ -726,6 +726,19 @@ def __repr__(self): def __len__(self): return len(self.dataframe) + def __eq__(self, other): + if isinstance(other, 
self.__class__): + if self._core.has_connector and other._core.has_connector: + return self._core.connector.equals(other._core.connector) + + return False + + def is_connector(self): + return self._core.has_connector + + def get_query_exec_func(self): + return self._core.connector.execute_direct_sql_query + def load_smartdataframes( dfs: List[Union[DataFrameType, Any]], config: Config diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index b480f1e4a..111eaf155 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -23,11 +23,9 @@ import traceback from pandasai.constants import DEFAULT_CHART_DIRECTORY from pandasai.helpers.skills_manager import SkillsManager - +from pandasai.prompts.direct_sql_prompt import DirectSQLPrompt from pandasai.skills import skill - from pandasai.helpers.query_exec_tracker import QueryExecTracker - from ..helpers.output_types import output_type_factory from ..helpers.viz_library_types import viz_lib_type_factory from pandasai.responses.context import Context @@ -42,13 +40,13 @@ from ..prompts.base import AbstractPrompt from ..prompts.correct_error_prompt import CorrectErrorPrompt from ..prompts.generate_python_code import GeneratePythonCodePrompt -from typing import Union, List, Any, Type, Optional +from typing import Union, List, Any, Optional from ..helpers.code_manager import CodeExecutionContext, CodeManager from ..middlewares.base import Middleware from ..helpers.df_info import DataFrameType from ..helpers.path import find_project_root from ..helpers.viz_library_types.base import VisualizationLibrary -from ..exceptions import AdvancedReasoningDisabledError +from ..exceptions import AdvancedReasoningDisabledError, InvalidConfigError class SmartDatalake: @@ -64,6 +62,7 @@ class SmartDatalake: _skills: SkillsManager _instance: str _query_exec_tracker: QueryExecTracker + _can_direct_sql: bool _last_code_generated: str = None _last_reasoning: str = None @@ -133,6 +132,9 @@ def __init__( server_config=self._config.log_server, ) + # If direct_sql is set, validate that all dataframes use the same SQL connector + self._can_direct_sql = self._validate_direct_sql(self._dfs) + def set_instance_type(self, type: str): self._instance = type @@ -247,6 +249,37 @@ def _load_data_viz_library(self, data_viz_library: str): if data_viz_library in (item.value for item in VisualizationLibrary): self._data_viz_library = data_viz_library + def _validate_direct_sql(self, dfs: List) -> bool: + """ + Validates that, when direct_sql is enabled, all dataframes are SQL connectors to the same datasource with the same credentials + Args: + dfs (List[SmartDataframe]): list of SmartDataframes + + Raises: + InvalidConfigError: Raised when direct_sql is set but the criteria are not met + """ + + if self._config.direct_sql and all(df.is_connector() for df in dfs): + if all(df == dfs[0] for df in dfs): + return True + else: + raise InvalidConfigError( + "direct_sql requires all dataframes to be SQL connectors of the " + "same type, pointing to the same datasource with the same credentials" + ) + return False + + def _get_chat_prompt(self): + key = "direct_sql_prompt" if self._config.direct_sql else "generate_python_code" + return ( + key, + ( + DirectSQLPrompt(tables=self._dfs) + if self._config.direct_sql + else GeneratePythonCodePrompt() + ), + ) + + def add_middlewares(self, *middlewares: Optional[Middleware]): """ Add middlewares to PandasAI instance. 
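In direct SQL mode, `CodeManager.execute_code` (see the code_manager.py hunk above) injects the first connector's `execute_direct_sql_query` into the execution environment as `execute_sql_query` and calls `analyze_data()` with no arguments. A minimal, runnable sketch of the code shape the direct SQL prompt asks the LLM to return — the in-memory sqlite connection and sample table are illustrative stand-ins, not part of this diff:

```python
import sqlite3

import pandas as pd

# Illustrative stand-in for the helper this PR injects at execution time from
# SQLConnector.execute_direct_sql_query; here it just queries an in-memory db.
_conn = sqlite3.connect(":memory:")
pd.DataFrame({"order_id": [1, 1, 2], "product_id": [10, 11, 10]}).to_sql(
    "order_details", _conn, index=False
)


def execute_sql_query(sql_query: str) -> pd.DataFrame:
    return pd.read_sql(sql_query, _conn)


# Shape of the function the direct_sql prompt asks for: no dfs parameter,
# the heavy lifting is pushed down to the database, and the usual dict result.
def analyze_data() -> dict:
    df = execute_sql_query(
        "SELECT order_id, COUNT(DISTINCT product_id) AS product_count "
        "FROM order_details GROUP BY order_id"
    )
    return {"type": "dataframe", "value": df}


print(analyze_data()["value"])
```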
@@ -273,7 +306,7 @@ def _assign_prompt_id(self): def _get_prompt( self, key: str, - default_prompt: Type[AbstractPrompt], + default_prompt: AbstractPrompt, default_values: Optional[dict] = None, ) -> AbstractPrompt: """ @@ -292,7 +325,7 @@ def _get_prompt( default_values = {} custom_prompt = self._config.custom_prompts.get(key) - prompt = custom_prompt or default_prompt() + prompt = custom_prompt or default_prompt # set default values for the prompt prompt.set_config(self._config) @@ -325,6 +358,10 @@ def _get_cache_key(self) -> str: hash = df.column_hash() cache_key += str(hash) + # direct flag to separate out caching for different codegen + if self._config.direct_sql: + cache_key += "direct_sql" + return cache_key def chat(self, query: str, output_type: Optional[str] = None): @@ -395,11 +432,12 @@ def chat(self, query: str, output_type: Optional[str] = None): ): default_values["current_code"] = self._last_code_generated + prompt_key, prompt = self._get_chat_prompt() generate_python_code_instruction = ( self._query_exec_tracker.execute_func( self._get_prompt, - "generate_python_code", - default_prompt=GeneratePythonCodePrompt, + key=prompt_key, + default_prompt=prompt, default_values=default_values, ) ) @@ -432,7 +470,9 @@ def chat(self, query: str, output_type: Optional[str] = None): while retry_count < self._config.max_retries: try: # Execute the code - context = CodeExecutionContext(self._last_prompt_id, self._skills) + context = CodeExecutionContext( + self._last_prompt_id, self._skills, self._can_direct_sql + ) result = self._code_manager.execute_code( code=code_to_run, context=context, @@ -553,7 +593,7 @@ def _retry_run_code(self, code: str, e: Exception) -> List: } error_correcting_instruction = self._get_prompt( "correct_error", - default_prompt=CorrectErrorPrompt, + default_prompt=CorrectErrorPrompt(), default_values=default_values, ) diff --git a/poetry.lock b/poetry.lock index a46f37b0a..108e30fae 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -532,7 +532,6 @@ files = [ {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18a64814ae7bce73925131381603fff0116e2df25230dfc80d6d690aa6e20b37"}, {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90c81f22b4f572f8a2110b0b741bb64e5a6427e0a198b2cdc1fbaf85f352a3aa"}, {file = "contourpy-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53cc3a40635abedbec7f1bde60f8c189c49e84ac180c665f2cd7c162cc454baa"}, - {file = "contourpy-1.1.0-cp310-cp310-win32.whl", hash = "sha256:9b2dd2ca3ac561aceef4c7c13ba654aaa404cf885b187427760d7f7d4c57cff8"}, {file = "contourpy-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:1f795597073b09d631782e7245016a4323cf1cf0b4e06eef7ea6627e06a37ff2"}, {file = "contourpy-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0b7b04ed0961647691cfe5d82115dd072af7ce8846d31a5fac6c142dcce8b882"}, {file = "contourpy-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27bc79200c742f9746d7dd51a734ee326a292d77e7d94c8af6e08d1e6c15d545"}, @@ -541,7 +540,6 @@ files = [ {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5cec36c5090e75a9ac9dbd0ff4a8cf7cecd60f1b6dc23a374c7d980a1cd710e"}, {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f0cbd657e9bde94cd0e33aa7df94fb73c1ab7799378d3b3f902eb8eb2e04a3a"}, {file = "contourpy-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:181cbace49874f4358e2929aaf7ba84006acb76694102e88dd15af861996c16e"}, - {file = "contourpy-1.1.0-cp311-cp311-win32.whl", hash = "sha256:edb989d31065b1acef3828a3688f88b2abb799a7db891c9e282df5ec7e46221b"}, {file = "contourpy-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fb3b7d9e6243bfa1efb93ccfe64ec610d85cfe5aec2c25f97fbbd2e58b531256"}, {file = "contourpy-1.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bcb41692aa09aeb19c7c213411854402f29f6613845ad2453d30bf421fe68fed"}, {file = "contourpy-1.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5d123a5bc63cd34c27ff9c7ac1cd978909e9c71da12e05be0231c608048bb2ae"}, @@ -550,7 +548,6 @@ files = [ {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:317267d915490d1e84577924bd61ba71bf8681a30e0d6c545f577363157e5e94"}, {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d551f3a442655f3dcc1285723f9acd646ca5858834efeab4598d706206b09c9f"}, {file = "contourpy-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e7a117ce7df5a938fe035cad481b0189049e8d92433b4b33aa7fc609344aafa1"}, - {file = "contourpy-1.1.0-cp38-cp38-win32.whl", hash = "sha256:108dfb5b3e731046a96c60bdc46a1a0ebee0760418951abecbe0fc07b5b93b27"}, {file = "contourpy-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:d4f26b25b4f86087e7d75e63212756c38546e70f2a92d2be44f80114826e1cd4"}, {file = "contourpy-1.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc00bb4225d57bff7ebb634646c0ee2a1298402ec10a5fe7af79df9a51c1bfd9"}, {file = "contourpy-1.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:189ceb1525eb0655ab8487a9a9c41f42a73ba52d6789754788d1883fb06b2d8a"}, @@ -559,7 +556,6 @@ files = [ {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:143dde50520a9f90e4a2703f367cf8ec96a73042b72e68fcd184e1279962eb6f"}, {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:e94bef2580e25b5fdb183bf98a2faa2adc5b638736b2c0a4da98691da641316a"}, {file = "contourpy-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ed614aea8462735e7d70141374bd7650afd1c3f3cb0c2dbbcbe44e14331bf002"}, - {file = "contourpy-1.1.0-cp39-cp39-win32.whl", hash = "sha256:71551f9520f008b2950bef5f16b0e3587506ef4f23c734b71ffb7b89f8721999"}, {file = "contourpy-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:438ba416d02f82b692e371858143970ed2eb6337d9cdbbede0d8ad9f3d7dd17d"}, {file = "contourpy-1.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a698c6a7a432789e587168573a864a7ea374c6be8d4f31f9d87c001d5a843493"}, {file = "contourpy-1.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:397b0ac8a12880412da3551a8cb5a187d3298a72802b45a3bd1805e204ad8439"}, @@ -1306,11 +1302,11 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""}, {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -2478,16 +2474,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = 
"MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -3049,8 +3035,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.20.3", markers = "python_version < \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.1" @@ -3801,7 +3787,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -3809,15 +3794,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -3834,7 +3812,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -3842,7 +3819,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -4409,7 +4385,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"} +greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"win32\" or 
platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\")"} [package.extras] aiomysql = ["aiomysql (>=0.2.0)", "greenlet (!=0.4.17)"] @@ -4499,8 +4475,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.22.3", markers = "python_version == \"3.10\" and platform_system == \"Windows\" and platform_python_implementation != \"PyPy\""}, {version = ">=1.18", markers = "python_version != \"3.10\" or platform_system != \"Windows\" or platform_python_implementation == \"PyPy\""}, + {version = ">=1.22.3", markers = "python_version == \"3.10\" and platform_system == \"Windows\" and platform_python_implementation != \"PyPy\""}, ] packaging = ">=21.3" pandas = ">=1.0" diff --git a/pyproject.toml b/pyproject.toml index 6e20c093e..e42abb123 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pandasai" -version = "1.4.3" +version = "1.4.4" description = "PandasAI is a Python library that integrates generative artificial intelligence capabilities into Pandas, making dataframes conversational." authors = ["Gabriele Venturi"] license = "MIT" diff --git a/tests/connectors/test_sql.py b/tests/connectors/test_sql.py index e9d55fbb7..83f2e6723 100644 --- a/tests/connectors/test_sql.py +++ b/tests/connectors/test_sql.py @@ -2,7 +2,8 @@ import pandas as pd from unittest.mock import Mock, patch from pandasai.connectors.base import SQLConnectorConfig -from pandasai.connectors.sql import SQLConnector +from pandasai.connectors.sql import PostgreSQLConnector, SQLConnector +from pandasai.exceptions import MaliciousQueryError class TestSQLConnector(unittest.TestCase): @@ -104,3 +105,92 @@ def test_fallback_name_property(self): # Test fallback_name property fallback_name = self.connector.fallback_name self.assertEqual(fallback_name, "your_table") + + def test_is_sql_query_safe_safe_query(self): + safe_query = "SELECT * FROM users WHERE username = 'John'" + result = self.connector._is_sql_query_safe(safe_query) + assert result is True + + def test_is_sql_query_safe_malicious_query(self): + malicious_query = "DROP TABLE users" + result = self.connector._is_sql_query_safe(malicious_query) + assert result is False + + @patch("pandasai.connectors.sql.pd.read_sql", autospec=True) + def test_execute_direct_sql_query_safe_query(self, mock_sql): + safe_query = "SELECT * FROM users WHERE username = 'John'" + expected_data = pd.DataFrame({"Column1": [1, 2, 3], "Column2": [4, 5, 6]}) + mock_sql.return_value = expected_data + result = self.connector.execute_direct_sql_query(safe_query) + assert isinstance(result, pd.DataFrame) + + def test_execute_direct_sql_query_malicious_query(self): + malicious_query = "DROP TABLE users" + try: + self.connector.execute_direct_sql_query(malicious_query) + assert False, "MaliciousQueryError not raised" + except MaliciousQueryError: + pass + + @patch("pandasai.connectors.SQLConnector._init_connection") + def test_equals_identical_configs(self, mock_init_connection): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + connector_2 = SQLConnector(self.config) + + assert self.connector.equals(connector_2) + + 
@patch("pandasai.connectors.SQLConnector._load_connector_config") + @patch("pandasai.connectors.SQLConnector._init_connection") + def test_equals_different_configs( + self, mock_load_connector_config, mock_init_connection + ): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username_differ", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + connector_2 = SQLConnector(self.config) + + assert not self.connector.equals(connector_2) + + @patch("pandasai.connectors.SQLConnector._init_connection") + def test_equals_different_connector(self, mock_init_connection): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username_differ", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + connector_2 = PostgreSQLConnector(self.config) + + assert not self.connector.equals(connector_2) diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py index 64d70fc21..5358e862a 100644 --- a/tests/prompts/test_generate_python_code_prompt.py +++ b/tests/prompts/test_generate_python_code_prompt.py @@ -90,8 +90,6 @@ def test_str_with_args( Question -{viz_library_type_hint} - This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required @@ -103,6 +101,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + {viz_library_type_hint} At the end, return a dictionary of: {output_type_hint} """ @@ -151,8 +150,6 @@ def test_advanced_reasoning_prompt(self): Question -{viz_library_type_hint} - This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required @@ -164,6 +161,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) 
+ {viz_library_type_hint} At the end, return a dictionary of: """ diff --git a/tests/prompts/test_sql_prompt.py b/tests/prompts/test_sql_prompt.py new file mode 100644 index 000000000..533a50371 --- /dev/null +++ b/tests/prompts/test_sql_prompt.py @@ -0,0 +1,115 @@ +"""Unit tests for the direct SQL prompt class""" +import sys + +import pandas as pd +import pytest +from pandasai import SmartDataframe +from pandasai.llm.fake import FakeLLM +from pandasai.prompts.direct_sql_prompt import DirectSQLPrompt +from pandasai.helpers.viz_library_types import ( + MatplotlibVizLibraryType, + viz_lib_map, + viz_lib_type_factory, +) +from pandasai.helpers.output_types import ( + output_type_factory, + DefaultOutputType, + output_types_map, +) + + +class TestDirectSqlPrompt: + """Unit tests for the direct SQL prompt class""" + + @pytest.mark.parametrize( + "save_charts_path,output_type_hint,viz_library_type_hint", + [ + ( + "exports/charts", + DefaultOutputType().template_hint, + MatplotlibVizLibraryType().template_hint, + ), + ( + "custom/dir/for/charts", + DefaultOutputType().template_hint, + MatplotlibVizLibraryType().template_hint, + ), + *[ + ( + "exports/charts", + output_type_factory(type_).template_hint, + viz_lib_type_factory(viz_type_).template_hint, + ) + for type_ in output_types_map + for viz_type_ in viz_lib_map + ], + ], + ) + def test_direct_sql_prompt_with_params( + self, save_charts_path, output_type_hint, viz_library_type_hint + ): + """Test that the prompt renders the expected string for the given params""" + + llm = FakeLLM("plt.show()") + dfs = [ + SmartDataframe( + pd.DataFrame({}), + config={"llm": llm}, + ) + ] + + prompt = DirectSQLPrompt(tables=dfs) + prompt.set_var("dfs", dfs) + prompt.set_var("conversation", "What is the correct code?") + prompt.set_var("output_type_hint", output_type_hint) + prompt.set_var("save_charts_path", save_charts_path) + prompt.set_var("viz_library_type", viz_library_type_hint) + prompt_content = prompt.to_string() + if sys.platform.startswith("win"): + prompt_content = prompt_content.replace("\r\n", "\n") + + assert ( + prompt_content + == f'''You are provided with the following samples of sql tables data: + + + + + +
+ + + +What is the correct code? + + +You are provided with the following function that executes the sql query: + +def execute_sql_query(sql_query: str) -> pd.DataFrame +"""This method connects to the database, executes the sql query and returns the dataframe""" + + +This is the initial python function. Do not change the params. + +```python +# TODO import all the dependencies required +import pandas as pd + +def analyze_data() -> dict: + """ + Analyze the data, using the provided tables. + 1. Prepare: generate the sql query to get the data for analysis (grouping, filtering, aggregating, etc.) + 2. Process: execute the query using the execute_sql_query function available to you, which returns a dataframe + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + {viz_library_type_hint} + At the end, return a dictionary of: + {output_type_hint} + """ +``` + +Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. +Based on the last message in the conversation: + +- return the updated analyze_data function wrapped within `python `''' # noqa: E501 + ) diff --git a/tests/skills/test_skills.py b/tests/skills/test_skills.py index ed979951c..1e8ddcc50 100644 --- a/tests/skills/test_skills.py +++ b/tests/skills/test_skills.py @@ -71,8 +71,7 @@ def code_manager(self, smart_dataframe: SmartDataframe): @pytest.fixture def exec_context(self) -> MagicMock: - context = MagicMock(spec=CodeExecutionContext) - return context + return CodeExecutionContext(uuid.uuid4(), SkillsManager()) @pytest.fixture def agent(self, llm, sample_df): @@ -317,7 +316,7 @@ def test_run_prompt_without_skills(self, agent): ) def test_code_exec_with_skills_no_use( - self, code_manager: CodeManager, exec_context: MagicMock + self, code_manager: CodeManager, exec_context ): code = """def analyze_data(dfs): return {'type': 'number', 'value': 1 + 1}""" diff --git a/tests/test_codemanager.py b/tests/test_codemanager.py index 0df494429..f8af9de63 100644 --- a/tests/test_codemanager.py +++ b/tests/test_codemanager.py @@ -1,11 +1,13 @@ """Unit tests for the CodeManager class""" from typing import Optional from unittest.mock import MagicMock, Mock, patch +import uuid import pandas as pd import pytest from pandasai.exceptions import BadImportError, NoCodeFoundError +from pandasai.helpers.skills_manager import SkillsManager from pandasai.llm.fake import FakeLLM from pandasai.smart_dataframe import SmartDataframe @@ -73,8 +75,7 @@ def code_manager(self, smart_dataframe: SmartDataframe): @pytest.fixture def exec_context(self) -> MagicMock: - context = MagicMock(spec=CodeExecutionContext) - return context + return CodeExecutionContext(uuid.uuid4(), SkillsManager()) def test_run_code_for_calculations( self, code_manager: CodeManager, exec_context: MagicMock ): @@ -97,6 +98,8 @@ def test_clean_code_remove_builtins( builtins_code = """import set def analyze_data(dfs): return {'type': 'number', 'value': set([1, 2, 3])}""" + + exec_context._can_direct_sql = False assert code_manager.execute_code(builtins_code, exec_context)["value"] == { 1, 2, diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index b2f84c68e..9a6fa3a1a 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -219,10 +219,6 @@ def test_run_with_privacy_enforcement(self, llm): User: How many countries are in the dataframe? 
-When a user requests to create a chart, utilize the Python -matplotlib library to generate high-quality graphics that will be saved -directly to a file. - This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required @@ -234,6 +230,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + If the user requests to create a chart, utilize the Python matplotlib library to generate high-quality graphics that will be saved directly to a file. At the end, return a dictionary of: - type (possible values "string", "number", "dataframe", "plot") - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) @@ -284,10 +281,6 @@ def test_run_passing_output_type(self, llm, output_type, output_type_hint): User: How many countries are in the dataframe? -When a user requests to create a chart, utilize the Python -matplotlib library to generate high-quality graphics that will be saved -directly to a file. - This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required @@ -299,6 +292,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + If the user requests to create a chart, utilize the Python matplotlib library to generate high-quality graphics that will be saved directly to a file. At the end, return a dictionary of: {output_type_hint} """ @@ -1084,8 +1078,6 @@ def test_run_passing_viz_library_type( User: Plot the histogram of countries showing for each the gdp with distinct bar colors -%s - This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required @@ -1097,6 +1089,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) 
+ %s At the end, return a dictionary of: - type (possible values "string", "number", "dataframe", "plot") - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) diff --git a/tests/test_smartdatalake.py b/tests/test_smartdatalake.py index bb33936bc..3195a53de 100644 --- a/tests/test_smartdatalake.py +++ b/tests/test_smartdatalake.py @@ -9,12 +9,18 @@ import pytest from pandasai import SmartDataframe, SmartDatalake +from pandasai.connectors.base import SQLConnectorConfig +from pandasai.connectors.sql import PostgreSQLConnector, SQLConnector +from pandasai.exceptions import InvalidConfigError from pandasai.helpers.code_manager import CodeManager from pandasai.llm.fake import FakeLLM from pandasai.middlewares import Middleware from langchain import OpenAI +from pandasai.prompts.direct_sql_prompt import DirectSQLPrompt +from pandasai.prompts.generate_python_code import GeneratePythonCodePrompt + class TestSmartDatalake: """Unit tests for the SmartDatlake class""" @@ -66,6 +72,44 @@ def sample_df(self): } ) + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def sql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return SQLConnector(self.config) + + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def pgsql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return PostgreSQLConnector(self.config) + @pytest.fixture def smart_dataframe(self, llm, sample_df): return SmartDataframe(sample_df, config={"llm": llm, "enable_cache": False}) @@ -229,3 +273,59 @@ def analyze_data(dfs): smart_datalake.chat("How many countries are in the dataframe?") assert smart_datalake.last_answer == "Custom answer" assert smart_datalake.last_reasoning == "Custom reasoning" + + def test_get_chat_prompt(self, smart_datalake: SmartDatalake): + # Test case 1: direct_sql is True + smart_datalake._config.direct_sql = True + gen_key, gen_prompt = smart_datalake._get_chat_prompt() + expected_key = "direct_sql_prompt" + assert gen_key == expected_key + assert isinstance(gen_prompt, DirectSQLPrompt) + + # Test case 2: direct_sql is False + smart_datalake._config.direct_sql = False + gen_key, gen_prompt = smart_datalake._get_chat_prompt() + expected_key = "generate_python_code" + assert gen_key == expected_key + assert isinstance(gen_prompt, GeneratePythonCodePrompt) + + def test_validate_true_direct_sql_with_non_connector(self, llm, sample_df): + # no exception is raised with a non-connector dataframe; direct_sql simply stays disabled + SmartDatalake( + [sample_df], + config={"llm": llm, "enable_cache": False, "direct_sql": True}, + ) + + def test_validate_direct_sql_with_connector(self, llm, sql_connector): + # no exception is raised when using a single connector + SmartDatalake( + [sql_connector], + config={"llm": llm, "enable_cache": False, "direct_sql": True}, + ) + + def test_validate_false_direct_sql_with_connector(self, llm, 
sql_connector): + # no exception is raised when direct_sql is disabled + SmartDatalake( + [sql_connector], + config={"llm": llm, "enable_cache": False, "direct_sql": False}, + ) + + def test_validate_false_direct_sql_with_two_different_connector( + self, llm, sql_connector, pgsql_connector + ): + # no exception is raised for different connectors when direct_sql is disabled + SmartDatalake( + [sql_connector, pgsql_connector], + config={"llm": llm, "enable_cache": False, "direct_sql": False}, + ) + + def test_validate_true_direct_sql_with_two_different_connector( + self, llm, sql_connector, pgsql_connector + ): + # mixing two different connector types with direct_sql enabled + # must raise an InvalidConfigError + with pytest.raises(InvalidConfigError): + SmartDatalake( + [sql_connector, pgsql_connector], + config={"llm": llm, "enable_cache": False, "direct_sql": True}, + )
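The `_is_sql_query_safe` check added in sql.py is a regex denylist with word boundaries: it rejects write/DDL keywords anywhere in the query (including inside string literals), while letting identifiers that merely contain a keyword, such as a `created_at` column, through. A standalone sanity check of that behavior, with the function inlined here for illustration:

```python
import re

# Same denylist used by SQLConnector._is_sql_query_safe in this diff.
INFECTED_KEYWORDS = [
    r"\bINSERT\b",
    r"\bUPDATE\b",
    r"\bDELETE\b",
    r"\bDROP\b",
    r"\bEXEC\b",
    r"\bALTER\b",
    r"\bCREATE\b",
]


def is_sql_query_safe(query: str) -> bool:
    return not any(re.search(kw, query, re.IGNORECASE) for kw in INFECTED_KEYWORDS)


# Word boundaries keep identifiers like created_at from tripping the filter...
assert is_sql_query_safe("SELECT id FROM orders WHERE created_at > '2023-01-01'")
# ...while write statements are rejected, case-insensitively.
assert not is_sql_query_safe("drop table orders")
# The denylist is coarse: a keyword inside a string literal is also rejected.
assert not is_sql_query_safe("SELECT * FROM logs WHERE action = 'DELETE'")
```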